1// Copyright (c) 2016 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <memory>
16#include <unordered_map>
17#include <unordered_set>
18#include <utility>
19#include <vector>
20
21#include "gmock/gmock.h"
22#include "gtest/gtest.h"
23#include "source/opt/build_module.h"
24#include "source/opt/def_use_manager.h"
25#include "source/opt/ir_context.h"
26#include "source/opt/module.h"
27#include "spirv-tools/libspirv.hpp"
28#include "test/opt/pass_fixture.h"
29#include "test/opt/pass_utils.h"
30
31namespace spvtools {
32namespace opt {
33namespace analysis {
34namespace {
35
36using ::testing::Contains;
37using ::testing::UnorderedElementsAre;
38using ::testing::UnorderedElementsAreArray;
39
40// Returns the number of uses of |id|.
41uint32_t NumUses(const std::unique_ptr<IRContext>& context, uint32_t id) {
42  uint32_t count = 0;
43  context->get_def_use_mgr()->ForEachUse(
44      id, [&count](Instruction*, uint32_t) { ++count; });
45  return count;
46}
47
48// Returns the opcode of each use of |id|.
49//
50// If |id| is used multiple times in a single instruction, that instruction's
51// opcode will appear a corresponding number of times.
52std::vector<spv::Op> GetUseOpcodes(const std::unique_ptr<IRContext>& context,
53                                   uint32_t id) {
54  std::vector<spv::Op> opcodes;
55  context->get_def_use_mgr()->ForEachUse(
56      id, [&opcodes](Instruction* user, uint32_t) {
57        opcodes.push_back(user->opcode());
58      });
59  return opcodes;
60}
61
62// Disassembles the given |inst| and returns the disassembly.
63std::string DisassembleInst(Instruction* inst) {
64  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
65
66  std::vector<uint32_t> binary;
67  // We need this to generate the necessary header in the binary.
68  tools.Assemble("", &binary);
69  inst->ToBinaryWithoutAttachedDebugInsts(&binary);
70
71  std::string text;
72  // We'll need to check the underlying id numbers.
73  // So turn off friendly names for ids.
74  tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
75  while (!text.empty() && text.back() == '\n') text.pop_back();
76  return text;
77}
78
// A struct for holding expected id defs and uses.
//
// The test tables brace-initialize this aggregate as {defs, uses}, so the
// order of the two members matters.
struct InstDefUse {
  // (id, expected disassembly of the instruction that defines the id).
  using IdInstPair = std::pair<uint32_t, std::string>;
  // (id, expected disassembly of every instruction that uses the id).
  using IdInstsPair = std::pair<uint32_t, std::vector<std::string>>;

  // Ids and their corresponding def instructions.
  std::vector<IdInstPair> defs;
  // Ids and their corresponding use instructions.
  std::vector<IdInstsPair> uses;
};
89
90// Checks that the |actual_defs| and |actual_uses| are in accord with
91// |expected_defs_uses|.
92void CheckDef(const InstDefUse& expected_defs_uses,
93              const DefUseManager::IdToDefMap& actual_defs) {
94  // Check defs.
95  ASSERT_EQ(expected_defs_uses.defs.size(), actual_defs.size());
96  for (uint32_t i = 0; i < expected_defs_uses.defs.size(); ++i) {
97    const auto id = expected_defs_uses.defs[i].first;
98    const auto expected_def = expected_defs_uses.defs[i].second;
99    ASSERT_EQ(1u, actual_defs.count(id)) << "expected to def id [" << id << "]";
100    auto def = actual_defs.at(id);
101    if (def->opcode() != spv::Op::OpConstant) {
102      // Constants don't disassemble properly without a full context.
103      EXPECT_EQ(expected_def, DisassembleInst(actual_defs.at(id)));
104    }
105  }
106}
107
108using UserMap = std::unordered_map<uint32_t, std::vector<Instruction*>>;
109
110// Creates a mapping of all definitions to their users (except OpConstant).
111//
112// OpConstants are skipped because they cannot be disassembled in isolation.
113UserMap BuildAllUsers(const DefUseManager* mgr, uint32_t idBound) {
114  UserMap userMap;
115  for (uint32_t id = 0; id != idBound; ++id) {
116    if (mgr->GetDef(id)) {
117      mgr->ForEachUser(id, [id, &userMap](Instruction* user) {
118        if (user->opcode() != spv::Op::OpConstant) {
119          userMap[id].push_back(user);
120        }
121      });
122    }
123  }
124  return userMap;
125}
126
127// Constants don't disassemble properly without a full context, so skip them as
128// checks.
129void CheckUse(const InstDefUse& expected_defs_uses, const DefUseManager* mgr,
130              uint32_t idBound) {
131  UserMap actual_uses = BuildAllUsers(mgr, idBound);
132  // Check uses.
133  ASSERT_EQ(expected_defs_uses.uses.size(), actual_uses.size());
134  for (uint32_t i = 0; i < expected_defs_uses.uses.size(); ++i) {
135    const auto id = expected_defs_uses.uses[i].first;
136    const auto& expected_uses = expected_defs_uses.uses[i].second;
137
138    ASSERT_EQ(1u, actual_uses.count(id)) << "expected to use id [" << id << "]";
139    const auto& uses = actual_uses.at(id);
140
141    ASSERT_EQ(expected_uses.size(), uses.size())
142        << "id [" << id << "] # uses: expected: " << expected_uses.size()
143        << " actual: " << uses.size();
144
145    std::vector<std::string> actual_uses_disassembled;
146    for (const auto actual_use : uses) {
147      actual_uses_disassembled.emplace_back(DisassembleInst(actual_use));
148    }
149    EXPECT_THAT(actual_uses_disassembled,
150                UnorderedElementsAreArray(expected_uses));
151  }
152}
153
// The following test case mimics how LLVM handles induction variables.
// It is not very readable, but only the id defs and uses matter here, so
// there is no need to make this a valid OpPhi construct.
//
// Shared by the "OpPhi" cases of both ParseDefUseTest and ReplaceUseTest.
// Note that "%12 = OpConstant %10 1.0" disassembles as "... 1" in
// full-module output (see the ReplaceUseTest expectation).
const char kOpPhiTestFunction[] =
    " %1 = OpTypeVoid "
    " %6 = OpTypeInt 32 0 "
    "%10 = OpTypeFloat 32 "
    "%16 = OpTypeBool "
    " %3 = OpTypeFunction %1 "
    " %8 = OpConstant %6 0 "
    "%18 = OpConstant %6 1 "
    "%12 = OpConstant %10 1.0 "
    " %2 = OpFunction %1 None %3 "
    " %4 = OpLabel "
    "      OpBranch %5 "

    // Loop header: two induction-variable phis and the back-edge branch.
    " %5 = OpLabel "
    " %7 = OpPhi %6 %8 %4 %9 %5 "
    "%11 = OpPhi %10 %12 %4 %13 %5 "
    " %9 = OpIAdd %6 %7 %8 "
    "%13 = OpFAdd %10 %11 %12 "
    "%17 = OpSLessThan %16 %7 %18 "
    "      OpLoopMerge %19 %5 None "
    "      OpBranchConditional %17 %5 %19 "

    "%19 = OpLabel "
    "      OpReturn "
    "      OpFunctionEnd";
182
// One ParseDefUseTest case: module text plus the defs and uses a def-use
// analysis of that module is expected to report.
struct ParseDefUseCase {
  // Assembly text; passed through JoinAllInsts and built into a module.
  const char* text;
  // Expected defs and uses.
  InstDefUse du;
};

// Value-parameterized fixture; each instance runs one ParseDefUseCase.
using ParseDefUseTest = ::testing::TestWithParam<ParseDefUseCase>;
189
190TEST_P(ParseDefUseTest, Case) {
191  const auto& tc = GetParam();
192
193  // Build module.
194  const std::vector<const char*> text = {tc.text};
195  std::unique_ptr<IRContext> context =
196      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
197                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
198  ASSERT_NE(nullptr, context);
199
200  // Analyze def and use.
201  DefUseManager manager(context->module());
202
203  CheckDef(tc.du, manager.id_to_defs());
204  CheckUse(tc.du, &manager, context->module()->IdBound());
205}
206
// Each case pairs module text with the defs and uses that a freshly built
// DefUseManager should report for it. Expected instructions are written in
// the single-instruction form produced by DisassembleInst.
//
// Use lists never mention OpConstant users, because BuildAllUsers drops them.
// Likewise, CheckDef only checks presence (not text) for ids defined by
// OpConstant.
//
// Some modules are not valid SPIR-V (e.g. the OpVariable with a non-pointer
// type in the "cross function" case); only def/use bookkeeping is checked.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, ParseDefUseTest,
    ::testing::ValuesIn(std::vector<ParseDefUseCase>{
        {"", {{}, {}}},                              // no instruction
        {"OpMemoryModel Logical GLSL450", {{}, {}}}, // no def and use
        { // single def, no use
          "%1 = OpString \"wow\"",
          {
            {{1, "%1 = OpString \"wow\""}}, // defs
            {}                              // uses
          }
        },
        { // multiple def, no use
          "%1 = OpString \"hello\" "
          "%2 = OpString \"world\" "
          "%3 = OpTypeVoid",
          {
            {  // defs
              {1, "%1 = OpString \"hello\""},
              {2, "%2 = OpString \"world\""},
              {3, "%3 = OpTypeVoid"},
            },
            {} // uses
          }
        },
        { // multiple def, multiple use
          "%1 = OpTypeBool "
          "%2 = OpTypeVector %1 3 "
          "%3 = OpTypeMatrix %2 3",
          {
            { // defs
              {1, "%1 = OpTypeBool"},
              {2, "%2 = OpTypeVector %1 3"},
              {3, "%3 = OpTypeMatrix %2 3"},
            },
            { // uses
              {1, {"%2 = OpTypeVector %1 3"}},
              {2, {"%3 = OpTypeMatrix %2 3"}},
            }
          }
        },
        { // multiple use of the same id
          "%1 = OpTypeBool "
          "%2 = OpTypeVector %1 2 "
          "%3 = OpTypeVector %1 3 "
          "%4 = OpTypeVector %1 4",
          {
            { // defs
              {1, "%1 = OpTypeBool"},
              {2, "%2 = OpTypeVector %1 2"},
              {3, "%3 = OpTypeVector %1 3"},
              {4, "%4 = OpTypeVector %1 4"},
            },
            { // uses
              {1,
                {
                  "%2 = OpTypeVector %1 2",
                  "%3 = OpTypeVector %1 3",
                  "%4 = OpTypeVector %1 4",
                }
              },
            }
          }
        },
        { // labels
          "%1 = OpTypeVoid "
          "%2 = OpTypeBool "
          "%3 = OpTypeFunction %1 "
          "%4 = OpConstantTrue %2 "
          "%5 = OpFunction %1 None %3 "

          "%6 = OpLabel "
          "OpBranchConditional %4 %7 %8 "

          "%7 = OpLabel "
          "OpBranch %7 "

          "%8 = OpLabel "
          "OpReturn "

          "OpFunctionEnd",
          {
            { // defs
              {1, "%1 = OpTypeVoid"},
              {2, "%2 = OpTypeBool"},
              {3, "%3 = OpTypeFunction %1"},
              {4, "%4 = OpConstantTrue %2"},
              {5, "%5 = OpFunction %1 None %3"},
              {6, "%6 = OpLabel"},
              {7, "%7 = OpLabel"},
              {8, "%8 = OpLabel"},
            },
            { // uses
              {1, {
                    "%3 = OpTypeFunction %1",
                    "%5 = OpFunction %1 None %3",
                  }
              },
              {2, {"%4 = OpConstantTrue %2"}},
              {3, {"%5 = OpFunction %1 None %3"}},
              {4, {"OpBranchConditional %4 %7 %8"}},
              {7,
                {
                  "OpBranchConditional %4 %7 %8",
                  "OpBranch %7",
                }
              },
              {8, {"OpBranchConditional %4 %7 %8"}},
            }
          }
        },
        { // cross function
          "%1 = OpTypeBool "
          "%3 = OpTypeFunction %1 "
          "%2 = OpFunction %1 None %3 "

          "%4 = OpLabel "
          "%5 = OpVariable %1 Function "
          "%6 = OpFunctionCall %1 %2 %5 "
          "OpReturnValue %6 "

          "OpFunctionEnd",
          {
            { // defs
              {1, "%1 = OpTypeBool"},
              {2, "%2 = OpFunction %1 None %3"},
              {3, "%3 = OpTypeFunction %1"},
              {4, "%4 = OpLabel"},
              {5, "%5 = OpVariable %1 Function"},
              {6, "%6 = OpFunctionCall %1 %2 %5"},
            },
            { // uses
              {1,
                {
                  "%2 = OpFunction %1 None %3",
                  "%3 = OpTypeFunction %1",
                  "%5 = OpVariable %1 Function",
                  "%6 = OpFunctionCall %1 %2 %5",
                }
              },
              {2, {"%6 = OpFunctionCall %1 %2 %5"}},
              {3, {"%2 = OpFunction %1 None %3"}},
              {5, {"%6 = OpFunctionCall %1 %2 %5"}},
              {6, {"OpReturnValue %6"}},
            }
          }
        },
        { // selection merge and loop merge
          "%1 = OpTypeVoid "
          "%3 = OpTypeFunction %1 "
          "%10 = OpTypeBool "
          "%8 = OpConstantTrue %10 "
          "%2 = OpFunction %1 None %3 "

          "%4 = OpLabel "
          "OpLoopMerge %5 %4 None "
          "OpBranch %6 "

          "%5 = OpLabel "
          "OpReturn "

          "%6 = OpLabel "
          "OpSelectionMerge %7 None "
          "OpBranchConditional %8 %9 %7 "

          "%7 = OpLabel "
          "OpReturn "

          "%9 = OpLabel "
          "OpReturn "

          "OpFunctionEnd",
          {
            { // defs
              {1, "%1 = OpTypeVoid"},
              {2, "%2 = OpFunction %1 None %3"},
              {3, "%3 = OpTypeFunction %1"},
              {4, "%4 = OpLabel"},
              {5, "%5 = OpLabel"},
              {6, "%6 = OpLabel"},
              {7, "%7 = OpLabel"},
              {8, "%8 = OpConstantTrue %10"},
              {9, "%9 = OpLabel"},
              {10, "%10 = OpTypeBool"},
            },
            { // uses
              {1,
                {
                  "%2 = OpFunction %1 None %3",
                  "%3 = OpTypeFunction %1",
                }
              },
              {3, {"%2 = OpFunction %1 None %3"}},
              {4, {"OpLoopMerge %5 %4 None"}},
              {5, {"OpLoopMerge %5 %4 None"}},
              {6, {"OpBranch %6"}},
              {7,
                {
                  "OpSelectionMerge %7 None",
                  "OpBranchConditional %8 %9 %7",
                }
              },
              {8, {"OpBranchConditional %8 %9 %7"}},
              {9, {"OpBranchConditional %8 %9 %7"}},
              {10, {"%8 = OpConstantTrue %10"}},
            }
          }
        },
        { // Forward reference
          "OpDecorate %1 Block "
          "OpTypeForwardPointer %2 Input "
          "%3 = OpTypeInt 32 0 "
          "%1 = OpTypeStruct %3 "
          "%2 = OpTypePointer Input %3",
          {
            { // defs
              {1, "%1 = OpTypeStruct %3"},
              {2, "%2 = OpTypePointer Input %3"},
              {3, "%3 = OpTypeInt 32 0"},
            },
            { // uses
              {1, {"OpDecorate %1 Block"}},
              {2, {"OpTypeForwardPointer %2 Input"}},
              {3,
                {
                  "%1 = OpTypeStruct %3",
                  "%2 = OpTypePointer Input %3",
                }
              }
            },
          },
        },
        { // OpPhi
          kOpPhiTestFunction,
          {
            { // defs
              {1, "%1 = OpTypeVoid"},
              {2, "%2 = OpFunction %1 None %3"},
              {3, "%3 = OpTypeFunction %1"},
              {4, "%4 = OpLabel"},
              {5, "%5 = OpLabel"},
              {6, "%6 = OpTypeInt 32 0"},
              {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
              {8, "%8 = OpConstant %6 0"},
              {9, "%9 = OpIAdd %6 %7 %8"},
              {10, "%10 = OpTypeFloat 32"},
              {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
              {12, "%12 = OpConstant %10 1.0"},
              {13, "%13 = OpFAdd %10 %11 %12"},
              {16, "%16 = OpTypeBool"},
              {17, "%17 = OpSLessThan %16 %7 %18"},
              {18, "%18 = OpConstant %6 1"},
              {19, "%19 = OpLabel"},
            },
            { // uses
              {1,
                {
                  "%2 = OpFunction %1 None %3",
                  "%3 = OpTypeFunction %1",
                }
              },
              {3, {"%2 = OpFunction %1 None %3"}},
              {4,
                {
                  "%7 = OpPhi %6 %8 %4 %9 %5",
                  "%11 = OpPhi %10 %12 %4 %13 %5",
                }
              },
              {5,
                {
                  "OpBranch %5",
                  "%7 = OpPhi %6 %8 %4 %9 %5",
                  "%11 = OpPhi %10 %12 %4 %13 %5",
                  "OpLoopMerge %19 %5 None",
                  "OpBranchConditional %17 %5 %19",
                }
              },
              {6,
                {
                  // Can't check constants properly
                  // "%8 = OpConstant %6 0",
                  // "%18 = OpConstant %6 1",
                  "%7 = OpPhi %6 %8 %4 %9 %5",
                  "%9 = OpIAdd %6 %7 %8",
                }
              },
              {7,
                {
                  "%9 = OpIAdd %6 %7 %8",
                  "%17 = OpSLessThan %16 %7 %18",
                }
              },
              {8,
                {
                  "%7 = OpPhi %6 %8 %4 %9 %5",
                  "%9 = OpIAdd %6 %7 %8",
                }
              },
              {9, {"%7 = OpPhi %6 %8 %4 %9 %5"}},
              {10,
                {
                  // "%12 = OpConstant %10 1.0",
                  "%11 = OpPhi %10 %12 %4 %13 %5",
                  "%13 = OpFAdd %10 %11 %12",
                }
              },
              {11, {"%13 = OpFAdd %10 %11 %12"}},
              {12,
                {
                  "%11 = OpPhi %10 %12 %4 %13 %5",
                  "%13 = OpFAdd %10 %11 %12",
                }
              },
              {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
              {16, {"%17 = OpSLessThan %16 %7 %18"}},
              {17, {"OpBranchConditional %17 %5 %19"}},
              {18, {"%17 = OpSLessThan %16 %7 %18"}},
              {19,
                {
                  "OpLoopMerge %19 %5 None",
                  "OpBranchConditional %17 %5 %19",
                }
              },
            },
          },
        },
        { // OpPhi defining and referencing the same id.
          "%1 = OpTypeBool "
          "%3 = OpTypeFunction %1 "
          "%2 = OpConstantTrue %1 "
          "%4 = OpFunction %1 None %3 "
          "%6 = OpLabel "
          "     OpBranch %7 "
          "%7 = OpLabel "
          "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
          "     OpBranch %7 "
          "     OpFunctionEnd",
          {
            { // defs
              {1, "%1 = OpTypeBool"},
              {2, "%2 = OpConstantTrue %1"},
              {3, "%3 = OpTypeFunction %1"},
              {4, "%4 = OpFunction %1 None %3"},
              {6, "%6 = OpLabel"},
              {7, "%7 = OpLabel"},
              {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
            },
            { // uses
              {1,
                {
                  "%2 = OpConstantTrue %1",
                  "%3 = OpTypeFunction %1",
                  "%4 = OpFunction %1 None %3",
                  "%8 = OpPhi %1 %8 %7 %2 %6",
                }
              },
              {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
              {3, {"%4 = OpFunction %1 None %3"}},
              {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
              {7,
                {
                  "OpBranch %7",
                  "%8 = OpPhi %1 %8 %7 %2 %6",
                  "OpBranch %7",
                }
              },
              {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
            },
          },
        },
    })
);
// clang-format on
581
// One ReplaceUseTest case.
struct ReplaceUseCase {
  // Assembly text of the module before any replacement.
  const char* before;
  // (old id, new id) pairs, applied in order with ReplaceAllUsesWith.
  std::vector<std::pair<uint32_t, uint32_t>> candidates;
  // Expected full-module disassembly (see DisassembleModule) afterwards.
  const char* after;
  // Expected defs and uses afterwards.
  InstDefUse du;
};

// Value-parameterized fixture; each instance runs one ReplaceUseCase.
using ReplaceUseTest = ::testing::TestWithParam<ReplaceUseCase>;
590
591// Disassembles the given |module| and returns the disassembly.
592std::string DisassembleModule(Module* module) {
593  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
594
595  std::vector<uint32_t> binary;
596  module->ToBinary(&binary, /* skip_nop = */ false);
597
598  std::string text;
599  // We'll need to check the underlying id numbers.
600  // So turn off friendly names for ids.
601  tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
602  while (!text.empty() && text.back() == '\n') text.pop_back();
603  return text;
604}
605
606TEST_P(ReplaceUseTest, Case) {
607  const auto& tc = GetParam();
608
609  // Build module.
610  const std::vector<const char*> text = {tc.before};
611  std::unique_ptr<IRContext> context =
612      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
613                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
614  ASSERT_NE(nullptr, context);
615
616  // Force a re-build of def-use manager.
617  context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
618  (void)context->get_def_use_mgr();
619
620  // Do the substitution.
621  for (const auto& candidate : tc.candidates) {
622    context->ReplaceAllUsesWith(candidate.first, candidate.second);
623  }
624
625  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
626  CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
627  CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
628}
629
// For each case, ReplaceUseTest does the following:
// - assembles |before|;
// - applies every (old id, new id) pair in |candidates|, in order, via
//   IRContext::ReplaceAllUsesWith;
// - expects the module to disassemble to |after|: numeric ids, one
//   instruction per line, no trailing newline;
// - expects the context's def-use manager to match |du|.
// As in ParseDefUseTest, OpConstant users are never listed among the uses.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, ReplaceUseTest,
    ::testing::ValuesIn(std::vector<ReplaceUseCase>{
      { // no use, no replace request
        "", {}, "", {},
      },
      { // replace one use
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 3 "
        "%3 = OpTypeInt 32 0 ",
        {{1, 3}},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %3 3\n"
        "%3 = OpTypeInt 32 0",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %3 3"},
            {3, "%3 = OpTypeInt 32 0"},
          },
          { // uses
            {3, {"%2 = OpTypeVector %3 3"}},
          },
        },
      },
      { // replace and then replace back
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 3 "
        "%3 = OpTypeInt 32 0",
        {{1, 3}, {3, 1}},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %1 3\n"
        "%3 = OpTypeInt 32 0",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %1 3"},
            {3, "%3 = OpTypeInt 32 0"},
          },
          { // uses
            {1, {"%2 = OpTypeVector %1 3"}},
          },
        },
      },
      { // replace with the same id
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 3",
        {{1, 1}, {2, 2}, {3, 3}},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %1 3",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %1 3"},
          },
          { // uses
            {1, {"%2 = OpTypeVector %1 3"}},
          },
        },
      },
      { // replace in sequence
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 3 "
        "%3 = OpTypeInt 32 0 "
        "%4 = OpTypeInt 32 1 ",
        {{1, 3}, {3, 4}},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %4 3\n"
        "%3 = OpTypeInt 32 0\n"
        "%4 = OpTypeInt 32 1",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %4 3"},
            {3, "%3 = OpTypeInt 32 0"},
            {4, "%4 = OpTypeInt 32 1"},
          },
          { // uses
            {4, {"%2 = OpTypeVector %4 3"}},
          },
        },
      },
      { // replace multiple uses
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 2 "
        "%3 = OpTypeVector %1 3 "
        "%4 = OpTypeVector %1 4 "
        "%5 = OpTypeMatrix %2 2 "
        "%6 = OpTypeMatrix %3 3 "
        "%7 = OpTypeMatrix %4 4 "
        "%8 = OpTypeInt 32 0 "
        "%9 = OpTypeInt 32 1 "
        "%10 = OpTypeInt 64 0",
        {{1, 8}, {2, 9}, {4, 10}},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %8 2\n"
        "%3 = OpTypeVector %8 3\n"
        "%4 = OpTypeVector %8 4\n"
        "%5 = OpTypeMatrix %9 2\n"
        "%6 = OpTypeMatrix %3 3\n"
        "%7 = OpTypeMatrix %10 4\n"
        "%8 = OpTypeInt 32 0\n"
        "%9 = OpTypeInt 32 1\n"
        "%10 = OpTypeInt 64 0",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %8 2"},
            {3, "%3 = OpTypeVector %8 3"},
            {4, "%4 = OpTypeVector %8 4"},
            {5, "%5 = OpTypeMatrix %9 2"},
            {6, "%6 = OpTypeMatrix %3 3"},
            {7, "%7 = OpTypeMatrix %10 4"},
            {8, "%8 = OpTypeInt 32 0"},
            {9, "%9 = OpTypeInt 32 1"},
            {10, "%10 = OpTypeInt 64 0"},
          },
          { // uses
            {8,
              {
                "%2 = OpTypeVector %8 2",
                "%3 = OpTypeVector %8 3",
                "%4 = OpTypeVector %8 4",
              }
            },
            {9, {"%5 = OpTypeMatrix %9 2"}},
            {3, {"%6 = OpTypeMatrix %3 3"}},
            {10, {"%7 = OpTypeMatrix %10 4"}},
          },
        },
      },
      { // OpPhi.
        kOpPhiTestFunction,
        // replace one id used by OpPhi, replace one id generated by OpPhi
        {{9, 13}, {11, 9}},
         "%1 = OpTypeVoid\n"
         "%6 = OpTypeInt 32 0\n"
         "%10 = OpTypeFloat 32\n"
         "%16 = OpTypeBool\n"
         "%3 = OpTypeFunction %1\n"
         "%8 = OpConstant %6 0\n"
         "%18 = OpConstant %6 1\n"
         "%12 = OpConstant %10 1\n"
         "%2 = OpFunction %1 None %3\n"
         "%4 = OpLabel\n"
               "OpBranch %5\n"

         "%5 = OpLabel\n"
         "%7 = OpPhi %6 %8 %4 %13 %5\n" // %9 -> %13
        "%11 = OpPhi %10 %12 %4 %13 %5\n"
         "%9 = OpIAdd %6 %7 %8\n"
        "%13 = OpFAdd %10 %9 %12\n"       // %11 -> %9
        "%17 = OpSLessThan %16 %7 %18\n"
              "OpLoopMerge %19 %5 None\n"
              "OpBranchConditional %17 %5 %19\n"

        "%19 = OpLabel\n"
              "OpReturn\n"
              "OpFunctionEnd",
        {
          { // defs.
            {1, "%1 = OpTypeVoid"},
            {2, "%2 = OpFunction %1 None %3"},
            {3, "%3 = OpTypeFunction %1"},
            {4, "%4 = OpLabel"},
            {5, "%5 = OpLabel"},
            {6, "%6 = OpTypeInt 32 0"},
            {7, "%7 = OpPhi %6 %8 %4 %13 %5"},
            {8, "%8 = OpConstant %6 0"},
            {9, "%9 = OpIAdd %6 %7 %8"},
            {10, "%10 = OpTypeFloat 32"},
            {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
            {12, "%12 = OpConstant %10 1.0"},
            {13, "%13 = OpFAdd %10 %9 %12"},
            {16, "%16 = OpTypeBool"},
            {17, "%17 = OpSLessThan %16 %7 %18"},
            {18, "%18 = OpConstant %6 1"},
            {19, "%19 = OpLabel"},
          },
          { // uses
            {1,
              {
                "%2 = OpFunction %1 None %3",
                "%3 = OpTypeFunction %1",
              }
            },
            {3, {"%2 = OpFunction %1 None %3"}},
            {4,
              {
                "%7 = OpPhi %6 %8 %4 %13 %5",
                "%11 = OpPhi %10 %12 %4 %13 %5",
              }
            },
            {5,
              {
                "OpBranch %5",
                "%7 = OpPhi %6 %8 %4 %13 %5",
                "%11 = OpPhi %10 %12 %4 %13 %5",
                "OpLoopMerge %19 %5 None",
                "OpBranchConditional %17 %5 %19",
              }
            },
            {6,
              {
                // Can't properly check constants
                // "%8 = OpConstant %6 0",
                // "%18 = OpConstant %6 1",
                "%7 = OpPhi %6 %8 %4 %13 %5",
                "%9 = OpIAdd %6 %7 %8"
              }
            },
            {7,
              {
                "%9 = OpIAdd %6 %7 %8",
                "%17 = OpSLessThan %16 %7 %18",
              }
            },
            {8,
              {
                "%7 = OpPhi %6 %8 %4 %13 %5",
                "%9 = OpIAdd %6 %7 %8",
              }
            },
            {9, {"%13 = OpFAdd %10 %9 %12"}}, // uses of %9 changed from %7 to %13
            {10,
              {
                "%11 = OpPhi %10 %12 %4 %13 %5",
                // "%12 = OpConstant %10 1",
                "%13 = OpFAdd %10 %9 %12"
              }
            },
            // no more uses of %11
            {12,
              {
                "%11 = OpPhi %10 %12 %4 %13 %5",
                "%13 = OpFAdd %10 %9 %12"
              }
            },
            {13, {
                   "%7 = OpPhi %6 %8 %4 %13 %5",
                   "%11 = OpPhi %10 %12 %4 %13 %5",
                 }
            },
            {16, {"%17 = OpSLessThan %16 %7 %18"}},
            {17, {"OpBranchConditional %17 %5 %19"}},
            {18, {"%17 = OpSLessThan %16 %7 %18"}},
            {19,
              {
                "OpLoopMerge %19 %5 None",
                "OpBranchConditional %17 %5 %19",
              }
            },
          },
        },
      },
      // NOTE(review): unlike the matching ParseDefUseTest case, this one
      // spells the function as "OpFunction %3 None %1". That swaps the result
      // type and function type operands, so the text is not valid SPIR-V. The
      // expectations below are written against this text, so the test is
      // self-consistent.
      { // OpPhi defining and referencing the same id.
        "%1 = OpTypeBool "
        "%3 = OpTypeFunction %1 "
        "%2 = OpConstantTrue %1 "

        "%4 = OpFunction %3 None %1 "
        "%6 = OpLabel "
        "     OpBranch %7 "
        "%7 = OpLabel "
        "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
        "     OpBranch %7 "
        "     OpFunctionEnd",
        {{8, 2}},
        "%1 = OpTypeBool\n"
        "%3 = OpTypeFunction %1\n"
        "%2 = OpConstantTrue %1\n"

        "%4 = OpFunction %3 None %1\n"
        "%6 = OpLabel\n"
             "OpBranch %7\n"
        "%7 = OpLabel\n"
        "%8 = OpPhi %1 %2 %7 %2 %6\n" // use of %8 changed to %2
             "OpBranch %7\n"
             "OpFunctionEnd",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpConstantTrue %1"},
            {3, "%3 = OpTypeFunction %1"},
            {4, "%4 = OpFunction %3 None %1"},
            {6, "%6 = OpLabel"},
            {7, "%7 = OpLabel"},
            {8, "%8 = OpPhi %1 %2 %7 %2 %6"},
          },
          { // uses
            {1,
              {
                "%2 = OpConstantTrue %1",
                "%3 = OpTypeFunction %1",
                "%4 = OpFunction %3 None %1",
                "%8 = OpPhi %1 %2 %7 %2 %6",
              }
            },
            {2,
              {
                // Only checking users
                "%8 = OpPhi %1 %2 %7 %2 %6",
              }
            },
            {3, {"%4 = OpFunction %3 None %1"}},
            {6, {"%8 = OpPhi %1 %2 %7 %2 %6"}},
            {7,
              {
                "OpBranch %7",
                "%8 = OpPhi %1 %2 %7 %2 %6",
                "OpBranch %7",
              }
            },
            // After {8, 2}, %8 is no longer used anywhere, so it has no entry:
            // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
          },
        },
      },
    })
);
// clang-format on
951
// Parameters for one KillDefTest case: the module is assembled from |before|,
// each id in |ids_to_kill| is passed to IRContext::KillDef in order, and the
// result is compared against |after| and |du|.
struct KillDefCase {
  // Module text to assemble (ids are preserved numerically).
  const char* before;
  // Ids handed to KillDef, in order. May contain ids that are never defined.
  std::vector<uint32_t> ids_to_kill;
  // Expected disassembly of the module after all kills.
  const char* after;
  // Expected defs and uses reported by the context's def-use manager.
  InstDefUse du;
};

using KillDefTest = ::testing::TestWithParam<KillDefCase>;
960
961TEST_P(KillDefTest, Case) {
962  const auto& tc = GetParam();
963
964  // Build module.
965  const std::vector<const char*> text = {tc.before};
966  std::unique_ptr<IRContext> context =
967      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
968                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
969  ASSERT_NE(nullptr, context);
970
971  // Analyze def and use.
972  DefUseManager manager(context->module());
973
974  // Do the substitution.
975  for (const auto id : tc.ids_to_kill) context->KillDef(id);
976
977  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
978  CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
979  CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
980}
981
// Cases for KillDefTest. In the expected output, instructions that survive keep
// any (now dangling) references to killed ids, while the def-use manager stops
// recording uses of the killed ids.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, KillDefTest,
    ::testing::ValuesIn(std::vector<KillDefCase>{
      { // no def, no use, no kill
        "", {}, "", {}
      },
      { // kill nothing
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 2 "
        "%3 = OpTypeVector %1 3 ",
        {},
        "%1 = OpTypeBool\n"
        "%2 = OpTypeVector %1 2\n"
        "%3 = OpTypeVector %1 3",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpTypeVector %1 2"},
            {3, "%3 = OpTypeVector %1 3"},
          },
          { // uses
            {1,
              {
                "%2 = OpTypeVector %1 2",
                "%3 = OpTypeVector %1 3",
              }
            },
          },
        },
      },
      { // kill id used, kill id not used, kill id not defined
        // %1 and %3 are used, %5 is unused, and %10 is never defined.
        "%1 = OpTypeBool "
        "%2 = OpTypeVector %1 2 "
        "%3 = OpTypeVector %1 3 "
        "%4 = OpTypeVector %1 4 "
        "%5 = OpTypeMatrix %3 3 "
        "%6 = OpTypeMatrix %2 3",
        {1, 3, 5, 10}, // ids to kill
        // %2 and %4 still print their reference to the killed %1.
        "%2 = OpTypeVector %1 2\n"
        "%4 = OpTypeVector %1 4\n"
        "%6 = OpTypeMatrix %2 3",
        {
          { // defs
            {2, "%2 = OpTypeVector %1 2"},
            {4, "%4 = OpTypeVector %1 4"},
            {6, "%6 = OpTypeMatrix %2 3"},
          },
          { // uses. %1 and %3 are both killed, so no uses
            // recorded for them anymore.
            {2, {"%6 = OpTypeMatrix %2 3"}},
          }
        },
      },
      { // OpPhi.
        kOpPhiTestFunction,
        {9, 11}, // kill one id used by OpPhi, kill one id generated by OpPhi
         "%1 = OpTypeVoid\n"
         "%6 = OpTypeInt 32 0\n"
         "%10 = OpTypeFloat 32\n"
         "%16 = OpTypeBool\n"
         "%3 = OpTypeFunction %1\n"
         "%8 = OpConstant %6 0\n"
         "%18 = OpConstant %6 1\n"
         "%12 = OpConstant %10 1\n"
         "%2 = OpFunction %1 None %3\n"
         "%4 = OpLabel\n"
               "OpBranch %5\n"

         "%5 = OpLabel\n"
         "%7 = OpPhi %6 %8 %4 %9 %5\n"
        "%13 = OpFAdd %10 %11 %12\n"
        "%17 = OpSLessThan %16 %7 %18\n"
              "OpLoopMerge %19 %5 None\n"
              "OpBranchConditional %17 %5 %19\n"

        "%19 = OpLabel\n"
              "OpReturn\n"
              "OpFunctionEnd",
        {
          { // defs. %9 & %11 are killed.
            {1, "%1 = OpTypeVoid"},
            {2, "%2 = OpFunction %1 None %3"},
            {3, "%3 = OpTypeFunction %1"},
            {4, "%4 = OpLabel"},
            {5, "%5 = OpLabel"},
            {6, "%6 = OpTypeInt 32 0"},
            {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
            {8, "%8 = OpConstant %6 0"},
            {10, "%10 = OpTypeFloat 32"},
            // Spelled "1.0" here while the expected disassembly above prints
            // "1"; per the "Can't properly check constants" note below,
            // constant defs are presumably not compared textually.
            {12, "%12 = OpConstant %10 1.0"},
            {13, "%13 = OpFAdd %10 %11 %12"},
            {16, "%16 = OpTypeBool"},
            {17, "%17 = OpSLessThan %16 %7 %18"},
            {18, "%18 = OpConstant %6 1"},
            {19, "%19 = OpLabel"},
          },
          { // uses
            {1,
              {
                "%2 = OpFunction %1 None %3",
                "%3 = OpTypeFunction %1",
              }
            },
            {3, {"%2 = OpFunction %1 None %3"}},
            {4,
              {
                "%7 = OpPhi %6 %8 %4 %9 %5",
                // "%11 = OpPhi %10 %12 %4 %13 %5",
              }
            },
            {5,
              {
                "OpBranch %5",
                "%7 = OpPhi %6 %8 %4 %9 %5",
                // "%11 = OpPhi %10 %12 %4 %13 %5",
                "OpLoopMerge %19 %5 None",
                "OpBranchConditional %17 %5 %19",
              }
            },
            {6,
              {
                // Can't properly check constants
                // "%8 = OpConstant %6 0",
                // "%18 = OpConstant %6 1",
                "%7 = OpPhi %6 %8 %4 %9 %5",
                // "%9 = OpIAdd %6 %7 %8"
              }
            },
            {7, {"%17 = OpSLessThan %16 %7 %18"}},
            {8,
              {
                "%7 = OpPhi %6 %8 %4 %9 %5",
                // "%9 = OpIAdd %6 %7 %8",
              }
            },
            // {9, {"%7 = OpPhi %6 %8 %4 %13 %5"}},
            {10,
              {
                // "%11 = OpPhi %10 %12 %4 %13 %5",
                // "%12 = OpConstant %10 1",
                "%13 = OpFAdd %10 %11 %12"
              }
            },
            // {11, {"%13 = OpFAdd %10 %11 %12"}},
            {12,
              {
                // "%11 = OpPhi %10 %12 %4 %13 %5",
                "%13 = OpFAdd %10 %11 %12"
              }
            },
            // {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
            {16, {"%17 = OpSLessThan %16 %7 %18"}},
            {17, {"OpBranchConditional %17 %5 %19"}},
            {18, {"%17 = OpSLessThan %16 %7 %18"}},
            {19,
              {
                "OpLoopMerge %19 %5 None",
                "OpBranchConditional %17 %5 %19",
              }
            },
          },
        },
      },
      { // OpPhi defining and referencing the same id.
        // Killing %8 removes the OpPhi from the output, and with it the uses
        // it contributed to %1, %2, %6 and %7.
        "%1 = OpTypeBool "
        "%3 = OpTypeFunction %1 "
        "%2 = OpConstantTrue %1 "
        "%4 = OpFunction %3 None %1 "
        "%6 = OpLabel "
        "     OpBranch %7 "
        "%7 = OpLabel "
        "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
        "     OpBranch %7 "
        "     OpFunctionEnd",
        {8},
        "%1 = OpTypeBool\n"
        "%3 = OpTypeFunction %1\n"
        "%2 = OpConstantTrue %1\n"

        "%4 = OpFunction %3 None %1\n"
        "%6 = OpLabel\n"
             "OpBranch %7\n"
        "%7 = OpLabel\n"
             "OpBranch %7\n"
             "OpFunctionEnd",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpConstantTrue %1"},
            {3, "%3 = OpTypeFunction %1"},
            {4, "%4 = OpFunction %3 None %1"},
            {6, "%6 = OpLabel"},
            {7, "%7 = OpLabel"},
            // {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
          },
          { // uses
            {1,
              {
                "%2 = OpConstantTrue %1",
                "%3 = OpTypeFunction %1",
                "%4 = OpFunction %3 None %1",
                // "%8 = OpPhi %1 %8 %7 %2 %6",
              }
            },
            // {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
            {3, {"%4 = OpFunction %3 None %1"}},
            // {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
            {7,
              {
                "OpBranch %7",
                // "%8 = OpPhi %1 %8 %7 %2 %6",
                "OpBranch %7",
              }
            },
            // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
          },
        },
      },
    })
);
// clang-format on
1204
1205TEST(DefUseTest, OpSwitch) {
1206  // Because disassembler has basic type check for OpSwitch's selector, we
1207  // cannot use the DisassembleInst() in the above. Thus, this special spotcheck
1208  // test case.
1209
1210  const char original_text[] =
1211      // int64 f(int64 v) {
1212      //   switch (v) {
1213      //     case 1:                   break;
1214      //     case -4294967296:         break;
1215      //     case 9223372036854775807: break;
1216      //     default:                  break;
1217      //   }
1218      //   return v;
1219      // }
1220      " %1 = OpTypeInt 64 1 "
1221      " %3 = OpTypePointer Input %1 "
1222      " %2 = OpFunction %1 None %3 "  // %3 is int64(int64)*
1223      " %4 = OpFunctionParameter %1 "
1224      " %5 = OpLabel "
1225      " %6 = OpLoad %1 %4 "  // selector value
1226      "      OpSelectionMerge %7 None "
1227      "      OpSwitch %6 %8 "
1228      "                  1                    %9 "  // 1
1229      "                  -4294967296         %10 "  // -2^32
1230      "                  9223372036854775807 %11 "  // 2^63-1
1231      " %8 = OpLabel "                              // default
1232      "      OpBranch %7 "
1233      " %9 = OpLabel "
1234      "      OpBranch %7 "
1235      "%10 = OpLabel "
1236      "      OpBranch %7 "
1237      "%11 = OpLabel "
1238      "      OpBranch %7 "
1239      " %7 = OpLabel "
1240      "      OpReturnValue %6 "
1241      "      OpFunctionEnd";
1242
1243  std::unique_ptr<IRContext> context =
1244      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text,
1245                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1246  ASSERT_NE(nullptr, context);
1247
1248  // Force a re-build of def-use manager.
1249  context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1250  (void)context->get_def_use_mgr();
1251
1252  // Do a bunch replacements.
1253  context->ReplaceAllUsesWith(11, 7);   // to existing id
1254  context->ReplaceAllUsesWith(10, 11);  // to existing id
1255  context->ReplaceAllUsesWith(9, 10);   // to existing id
1256
1257  // clang-format off
1258  const char modified_text[] =
1259       "%1 = OpTypeInt 64 1\n"
1260       "%3 = OpTypePointer Input %1\n"
1261       "%2 = OpFunction %1 None %3\n" // %3 is int64(int64)*
1262       "%4 = OpFunctionParameter %1\n"
1263       "%5 = OpLabel\n"
1264       "%6 = OpLoad %1 %4\n" // selector value
1265            "OpSelectionMerge %7 None\n"
1266            "OpSwitch %6 %8 1 %10 -4294967296 %11 9223372036854775807 %7\n" // changed!
1267       "%8 = OpLabel\n"      // default
1268            "OpBranch %7\n"
1269       "%9 = OpLabel\n"
1270            "OpBranch %7\n"
1271      "%10 = OpLabel\n"
1272            "OpBranch %7\n"
1273      "%11 = OpLabel\n"
1274            "OpBranch %7\n"
1275       "%7 = OpLabel\n"
1276            "OpReturnValue %6\n"
1277            "OpFunctionEnd";
1278  // clang-format on
1279
1280  EXPECT_EQ(modified_text, DisassembleModule(context->module()));
1281
1282  InstDefUse def_uses = {};
1283  def_uses.defs = {
1284      {1, "%1 = OpTypeInt 64 1"},
1285      {2, "%2 = OpFunction %1 None %3"},
1286      {3, "%3 = OpTypePointer Input %1"},
1287      {4, "%4 = OpFunctionParameter %1"},
1288      {5, "%5 = OpLabel"},
1289      {6, "%6 = OpLoad %1 %4"},
1290      {7, "%7 = OpLabel"},
1291      {8, "%8 = OpLabel"},
1292      {9, "%9 = OpLabel"},
1293      {10, "%10 = OpLabel"},
1294      {11, "%11 = OpLabel"},
1295  };
1296  CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs());
1297
1298  {
1299    EXPECT_EQ(2u, NumUses(context, 6));
1300    std::vector<spv::Op> opcodes = GetUseOpcodes(context, 6u);
1301    EXPECT_THAT(opcodes, UnorderedElementsAre(spv::Op::OpSwitch,
1302                                              spv::Op::OpReturnValue));
1303  }
1304  {
1305    EXPECT_EQ(6u, NumUses(context, 7));
1306    std::vector<spv::Op> opcodes = GetUseOpcodes(context, 7u);
1307    // OpSwitch is now a user of %7.
1308    EXPECT_THAT(opcodes, UnorderedElementsAre(
1309                             spv::Op::OpSelectionMerge, spv::Op::OpBranch,
1310                             spv::Op::OpBranch, spv::Op::OpBranch,
1311                             spv::Op::OpBranch, spv::Op::OpSwitch));
1312  }
1313  // Check all ids only used by OpSwitch after replacement.
1314  for (const auto id : {8u, 10u, 11u}) {
1315    EXPECT_EQ(1u, NumUses(context, id));
1316    EXPECT_EQ(spv::Op::OpSwitch, GetUseOpcodes(context, id).back());
1317  }
1318}
1319
// Test case for analyzing individual instructions: |module_text| is assembled
// (without preserving numeric ids) and a fresh DefUseManager over it must
// report exactly |expected_define_use|.
struct AnalyzeInstDefUseTestCase {
  const char* module_text;
  InstDefUse expected_define_use;
};

using AnalyzeInstDefUseTest =
    ::testing::TestWithParam<AnalyzeInstDefUseTestCase>;
1328
1329// Test the analyzing result for individual instructions.
1330TEST_P(AnalyzeInstDefUseTest, Case) {
1331  auto tc = GetParam();
1332
1333  // Build module.
1334  std::unique_ptr<IRContext> context =
1335      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text);
1336  ASSERT_NE(nullptr, context);
1337
1338  // Analyze the instructions.
1339  DefUseManager manager(context->module());
1340
1341  CheckDef(tc.expected_define_use, manager.id_to_defs());
1342  CheckUse(tc.expected_define_use, &manager, context->module()->IdBound());
1343  // CheckUse(tc.expected_define_use, manager.id_to_uses());
1344}
1345
// Cases for AnalyzeInstDefUseTest: small modules with their complete def and
// use sets.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, AnalyzeInstDefUseTest,
    ::testing::ValuesIn(std::vector<AnalyzeInstDefUseTestCase>{
      { // A type declaring instruction.
        "%1 = OpTypeInt 32 1",
        {
          // defs
          {{1, "%1 = OpTypeInt 32 1"}},
          {}, // no uses
        },
      },
      { // A type declaring instruction and a constant value.
        "%1 = OpTypeBool "
        "%2 = OpConstantTrue %1",
        {
          { // defs
            {1, "%1 = OpTypeBool"},
            {2, "%2 = OpConstantTrue %1"},
          },
          { // uses
            {1, {"%2 = OpConstantTrue %1"}},
          },
        },
      },
      }));
// clang-format on
1373
// Suite name for the plain TEST()s below. TEST() always derives from
// ::testing::Test directly, so this alias is not used as a fixture.
using AnalyzeInstDefUse = ::testing::Test;
1375
1376TEST(AnalyzeInstDefUse, UseWithNoResultId) {
1377  IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr);
1378
1379  // Analyze the instructions.
1380  DefUseManager manager(context.module());
1381
1382  Instruction label(&context, spv::Op::OpLabel, 0, 2, {});
1383  manager.AnalyzeInstDefUse(&label);
1384
1385  Instruction branch(&context, spv::Op::OpBranch, 0, 0,
1386                     {{SPV_OPERAND_TYPE_ID, {2}}});
1387  manager.AnalyzeInstDefUse(&branch);
1388  context.module()->SetIdBound(3);
1389
1390  InstDefUse expected = {
1391      // defs
1392      {
1393          {2, "%2 = OpLabel"},
1394      },
1395      // uses
1396      {{2, {"OpBranch %2"}}},
1397  };
1398
1399  CheckDef(expected, manager.id_to_defs());
1400  CheckUse(expected, &manager, context.module()->IdBound());
1401}
1402
1403TEST(AnalyzeInstDefUse, AddNewInstruction) {
1404  const std::string input = "%1 = OpTypeBool";
1405
1406  // Build module.
1407  std::unique_ptr<IRContext> context =
1408      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
1409  ASSERT_NE(nullptr, context);
1410
1411  // Analyze the instructions.
1412  DefUseManager manager(context->module());
1413
1414  Instruction newInst(context.get(), spv::Op::OpConstantTrue, 1, 2, {});
1415  manager.AnalyzeInstDefUse(&newInst);
1416
1417  InstDefUse expected = {
1418      {
1419          // defs
1420          {1, "%1 = OpTypeBool"},
1421          {2, "%2 = OpConstantTrue %1"},
1422      },
1423      {
1424          // uses
1425          {1, {"%2 = OpConstantTrue %1"}},
1426      },
1427  };
1428
1429  CheckDef(expected, manager.id_to_defs());
1430  CheckUse(expected, &manager, context->module()->IdBound());
1431}
1432
// Parameters for one KillInstTest case.
struct KillInstTestCase {
  // Module text to assemble (ids are preserved numerically).
  const char* before;
  // Despite the name, these are result ids: every instruction whose
  // result_id() is in this set is passed to IRContext::KillInst.
  std::unordered_set<uint32_t> indices_for_inst_to_kill;
  // Expected disassembly after the kills.
  const char* after;
  // Expected defs and uses reported by the context's def-use manager.
  InstDefUse expected_define_use;
};

using KillInstTest = ::testing::TestWithParam<KillInstTestCase>;
1441
1442TEST_P(KillInstTest, Case) {
1443  auto tc = GetParam();
1444
1445  // Build module.
1446  std::unique_ptr<IRContext> context =
1447      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before,
1448                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1449  ASSERT_NE(nullptr, context);
1450
1451  // Force a re-build of the def-use manager.
1452  context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1453  (void)context->get_def_use_mgr();
1454
1455  // KillInst
1456  context->module()->ForEachInst([&tc, &context](Instruction* inst) {
1457    if (tc.indices_for_inst_to_kill.count(inst->result_id())) {
1458      context->KillInst(inst);
1459    }
1460  });
1461
1462  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
1463  CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs());
1464  CheckUse(tc.expected_define_use, context->get_def_use_mgr(),
1465           context->module()->IdBound());
1466}
1467
// Cases for KillInstTest. In the expected output, some killed instructions
// vanish entirely while others are replaced in place by OpNop; references to
// killed ids in surviving instructions are left untouched.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, KillInstTest,
    ::testing::ValuesIn(std::vector<KillInstTestCase>{
      // Kill id defining instructions.
      // %3 (OpTypeVoid) disappears from the output; the labels %5 and %7
      // become OpNop. "%1 = OpTypeFunction %3" and "OpBranch %5" still print
      // the killed ids but are no longer recorded as their uses.
      {
        "%3 = OpTypeVoid "
        "%1 = OpTypeFunction %3 "
        "%2 = OpFunction %1 None %3 "
        "%4 = OpLabel "
        "     OpBranch %5 "
        "%5 = OpLabel "
        "     OpBranch %6 "
        "%6 = OpLabel "
        "     OpBranch %4 "
        "%7 = OpLabel "
        "     OpReturn "
        "     OpFunctionEnd",
        {3, 5, 7},
        "%1 = OpTypeFunction %3\n"
        "%2 = OpFunction %1 None %3\n"
        "%4 = OpLabel\n"
        "OpBranch %5\n"
        "OpNop\n"
        "OpBranch %6\n"
        "%6 = OpLabel\n"
        "OpBranch %4\n"
        "OpNop\n"
        "OpReturn\n"
        "OpFunctionEnd",
        {
          // defs
          {
            {1, "%1 = OpTypeFunction %3"},
            {2, "%2 = OpFunction %1 None %3"},
            {4, "%4 = OpLabel"},
            {6, "%6 = OpLabel"},
          },
          // uses
          {
            {1, {"%2 = OpFunction %1 None %3"}},
            {4, {"OpBranch %4"}},
            {6, {"OpBranch %6"}},
          }
        }
      },
      // Kill the OpFunction (%2) and the first OpLabel (%4). Both do have
      // result ids (the test selects instructions by result id); the expected
      // output shows each replaced in place by OpNop, and %4 is no longer
      // recorded as used even though "OpBranch %4" remains.
      {
        "%3 = OpTypeVoid "
        "%1 = OpTypeFunction %3 "
        "%2 = OpFunction %1 None %3 "
        "%4 = OpLabel "
        "     OpBranch %5 "
        "%5 = OpLabel "
        "     OpBranch %6 "
        "%6 = OpLabel "
        "     OpBranch %4 "
        "%7 = OpLabel "
        "     OpReturn "
        "     OpFunctionEnd",
        {2, 4},
        "%3 = OpTypeVoid\n"
        "%1 = OpTypeFunction %3\n"
             "OpNop\n"
             "OpNop\n"
             "OpBranch %5\n"
        "%5 = OpLabel\n"
             "OpBranch %6\n"
        "%6 = OpLabel\n"
             "OpBranch %4\n"
        "%7 = OpLabel\n"
             "OpReturn\n"
             "OpFunctionEnd",
        {
          // defs
          {
            {1, "%1 = OpTypeFunction %3"},
            {3, "%3 = OpTypeVoid"},
            {5, "%5 = OpLabel"},
            {6, "%6 = OpLabel"},
            {7, "%7 = OpLabel"},
          },
          // uses
          {
            {3, {"%1 = OpTypeFunction %3"}},
            {5, {"OpBranch %5"}},
            {6, {"OpBranch %6"}},
          }
        }
      },
      }));
// clang-format on
1560
// Parameters for one GetAnnotationsTest case.
struct GetAnnotationsTestCase {
  // Module text to assemble.
  const char* code;
  // Id whose annotations are queried.
  uint32_t id;
  // Expected disassembly of each returned annotation, in order.
  std::vector<std::string> annotations;
};

using GetAnnotationsTest = ::testing::TestWithParam<GetAnnotationsTestCase>;
1568
1569TEST_P(GetAnnotationsTest, Case) {
1570  const GetAnnotationsTestCase& tc = GetParam();
1571
1572  // Build module.
1573  std::unique_ptr<IRContext> context =
1574      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code);
1575  ASSERT_NE(nullptr, context);
1576
1577  // Get annotations
1578  DefUseManager manager(context->module());
1579  auto insts = manager.GetAnnotations(tc.id);
1580
1581  // Check
1582  ASSERT_EQ(tc.annotations.size(), insts.size())
1583      << "wrong number of annotation instructions";
1584  auto inst_iter = insts.begin();
1585  for (const std::string& expected_anno_inst : tc.annotations) {
1586    EXPECT_EQ(expected_anno_inst, DisassembleInst(*inst_iter))
1587        << "annotation instruction mismatch";
1588    inst_iter++;
1589  }
1590}
1591
// Cases for GetAnnotationsTest: which decoration instructions are reported for
// a given id.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    TestCase, GetAnnotationsTest,
    ::testing::ValuesIn(std::vector<GetAnnotationsTestCase>{
      // empty
      {"", 0, {}},
      // basic
      {
        // code
        "OpDecorate %1 Block "
        "OpDecorate %1 RelaxedPrecision "
        "%3 = OpTypeInt 32 0 "
        "%1 = OpTypeStruct %3",
        // id
        1,
        // annotations
        {
          "OpDecorate %1 Block",
          "OpDecorate %1 RelaxedPrecision",
        },
      },
      // with debug instructions (OpName is not reported as an annotation)
      {
        // code
        "OpName %1 \"struct_type\" "
        "OpName %3 \"int_type\" "
        "OpDecorate %1 Block "
        "OpDecorate %1 RelaxedPrecision "
        "%3 = OpTypeInt 32 0 "
        "%1 = OpTypeStruct %3",
        // id
        1,
        // annotations
        {
          "OpDecorate %1 Block",
          "OpDecorate %1 RelaxedPrecision",
        },
      },
      // no annotations
      {
        // code
        "OpName %1 \"struct_type\" "
        "OpName %3 \"int_type\" "
        "OpDecorate %1 Block "
        "OpDecorate %1 RelaxedPrecision "
        "%3 = OpTypeInt 32 0 "
        "%1 = OpTypeStruct %3",
        // id
        3,
        // annotations
        {},
      },
      // decoration group: %3 is decorated only through OpGroupDecorate; the
      // OpDecorate lines target the group %1 itself.
      {
        // code
        "OpDecorate %1 Block "
        "OpDecorate %1 RelaxedPrecision "
        "%1 = OpDecorationGroup "
        "OpGroupDecorate %1 %2 %3 "
        "%4 = OpTypeInt 32 0 "
        "%2 = OpTypeStruct %4 "
        "%3 = OpTypeStruct %4 %4",
        // id
        3,
        // annotations
        {
          "OpGroupDecorate %1 %2 %3",
        },
      },
      // member decorate
      {
        // code
        "OpMemberDecorate %1 0 RelaxedPrecision "
        "%2 = OpTypeInt 32 0 "
        "%1 = OpTypeStruct %2 %2",
        // id
        1,
        // annotations
        {
          "OpMemberDecorate %1 0 RelaxedPrecision",
        },
      },
      }));
// clang-format on

1676using UpdateUsesTest = PassTest<::testing::Test>;
1677
1678TEST_F(UpdateUsesTest, KeepOldUses) {
1679  const std::vector<const char*> text = {
1680      // clang-format off
1681      "OpCapability Shader",
1682      "%1 = OpExtInstImport \"GLSL.std.450\"",
1683      "OpMemoryModel Logical GLSL450",
1684      "OpEntryPoint Vertex %main \"main\"",
1685      "OpName %main \"main\"",
1686      "%void = OpTypeVoid",
1687      "%4 = OpTypeFunction %void",
1688      "%uint = OpTypeInt 32 0",
1689      "%uint_5 = OpConstant %uint 5",
1690      "%25 = OpConstant %uint 25",
1691      "%main = OpFunction %void None %4",
1692      "%8 = OpLabel",
1693      "%9 = OpIMul %uint %uint_5 %uint_5",
1694      "%10 = OpIMul %uint %9 %uint_5",
1695      "OpReturn",
1696      "OpFunctionEnd"
1697      // clang-format on
1698  };
1699
1700  std::unique_ptr<IRContext> context =
1701      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
1702                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1703  ASSERT_NE(nullptr, context);
1704
1705  DefUseManager* def_use_mgr = context->get_def_use_mgr();
1706  Instruction* def = def_use_mgr->GetDef(9);
1707  Instruction* use = def_use_mgr->GetDef(10);
1708  def->SetOpcode(spv::Op::OpCopyObject);
1709  def->SetInOperands({{SPV_OPERAND_TYPE_ID, {25}}});
1710  context->UpdateDefUse(def);
1711
1712  auto scanUser = [&](Instruction* user) { return user != use; };
1713  bool userFound = !def_use_mgr->WhileEachUser(def, scanUser);
1714
1715  EXPECT_TRUE(userFound);
1716}
1717// clang-format on
1718
1719}  // namespace
1720}  // namespace analysis
1721}  // namespace opt
1722}  // namespace spvtools
1723