1/*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/ir/SkSLSwitchStatement.h"
9
10#include <forward_list>
11
12#include "include/private/SkTHash.h"
13#include "src/sksl/SkSLAnalysis.h"
14#include "src/sksl/SkSLConstantFolder.h"
15#include "src/sksl/SkSLContext.h"
16#include "src/sksl/SkSLProgramSettings.h"
17#include "src/sksl/ir/SkSLBlock.h"
18#include "src/sksl/ir/SkSLNop.h"
19#include "src/sksl/ir/SkSLSymbolTable.h"
20#include "src/sksl/ir/SkSLType.h"
21
22namespace SkSL {
23
24std::unique_ptr<Statement> SwitchStatement::clone() const {
25    StatementArray cases;
26    cases.reserve_back(this->cases().size());
27    for (const std::unique_ptr<Statement>& stmt : this->cases()) {
28        cases.push_back(stmt->clone());
29    }
30    return std::make_unique<SwitchStatement>(fLine,
31                                             this->isStatic(),
32                                             this->value()->clone(),
33                                             std::move(cases),
34                                             SymbolTable::WrapIfBuiltin(this->symbols()));
35}
36
37String SwitchStatement::description() const {
38    String result;
39    if (this->isStatic()) {
40        result += "@";
41    }
42    result += String::printf("switch (%s) {\n", this->value()->description().c_str());
43    for (const auto& c : this->cases()) {
44        result += c->description();
45    }
46    result += "}";
47    return result;
48}
49
50static std::forward_list<const SwitchCase*> find_duplicate_case_values(
51        const StatementArray& cases) {
52    std::forward_list<const SwitchCase*> duplicateCases;
53    SkTHashSet<SKSL_INT> intValues;
54    bool foundDefault = false;
55
56    for (const std::unique_ptr<Statement>& stmt : cases) {
57        const SwitchCase* sc = &stmt->as<SwitchCase>();
58        const std::unique_ptr<Expression>& valueExpr = sc->value();
59
60        // A null case-value indicates the `default` switch-case.
61        if (!valueExpr) {
62            if (foundDefault) {
63                duplicateCases.push_front(sc);
64                continue;
65            }
66            foundDefault = true;
67            continue;
68        }
69
70        // GetConstantInt already succeeded when the SwitchCase was first assembled, so it should
71        // succeed this time too.
72        SKSL_INT intValue = 0;
73        SkAssertResult(ConstantFolder::GetConstantInt(*valueExpr, &intValue));
74        if (intValues.contains(intValue)) {
75            duplicateCases.push_front(sc);
76            continue;
77        }
78        intValues.add(intValue);
79    }
80
81    return duplicateCases;
82}
83
84static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
85    switch (stmt->kind()) {
86        case Statement::Kind::kBlock: {
87            // Recurse into the block.
88            Block& block = stmt->as<Block>();
89
90            StatementArray blockStmts;
91            blockStmts.reserve_back(block.children().size());
92            for (std::unique_ptr<Statement>& blockStmt : block.children()) {
93                move_all_but_break(blockStmt, &blockStmts);
94            }
95
96            target->push_back(Block::Make(block.fLine, std::move(blockStmts),
97                                          block.symbolTable(), block.isScope()));
98            break;
99        }
100
101        case Statement::Kind::kBreak:
102            // Do not append a break to the target.
103            break;
104
105        default:
106            // Append normal statements to the target.
107            target->push_back(std::move(stmt));
108            break;
109    }
110}
111
112std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
113                                                         SwitchCase* caseToCapture,
114                                                         std::shared_ptr<SymbolTable> symbolTable) {
115    // We have to be careful to not move any of the pointers until after we're sure we're going to
116    // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
117    // of action. First, find the switch-case we are interested in.
118    auto iter = cases->begin();
119    for (; iter != cases->end(); ++iter) {
120        const SwitchCase& sc = (*iter)->as<SwitchCase>();
121        if (&sc == caseToCapture) {
122            break;
123        }
124    }
125
126    // Next, walk forward through the rest of the switch. If we find a conditional break, we're
127    // stuck and can't simplify at all. If we find an unconditional break, we have a range of
128    // statements that we can use for simplification.
129    auto startIter = iter;
130    Statement* stripBreakStmt = nullptr;
131    for (; iter != cases->end(); ++iter) {
132        std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
133        if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
134            // We can't reduce switch-cases to a block when they have conditional exits.
135            return nullptr;
136        }
137        if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
138            // We found an unconditional exit. We can use this block, but we'll need to strip
139            // out the break statement if there is one.
140            stripBreakStmt = stmt.get();
141            break;
142        }
143    }
144
145    // We fell off the bottom of the switch or encountered a break. We know the range of statements
146    // that we need to move over, and we know it's safe to do so.
147    StatementArray caseStmts;
148    caseStmts.reserve_back(std::distance(startIter, iter) + 1);
149
150    // We can move over most of the statements as-is.
151    while (startIter != iter) {
152        caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
153        ++startIter;
154    }
155
156    // If we found an unconditional break at the end, we need to move what we can while avoiding
157    // that break.
158    if (stripBreakStmt != nullptr) {
159        SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
160        move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
161    }
162
163    // Return our newly-synthesized block.
164    return Block::Make(caseToCapture->fLine, std::move(caseStmts), std::move(symbolTable));
165}
166
167std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
168                                                    int line,
169                                                    bool isStatic,
170                                                    std::unique_ptr<Expression> value,
171                                                    ExpressionArray caseValues,
172                                                    StatementArray caseStatements,
173                                                    std::shared_ptr<SymbolTable> symbolTable) {
174    SkASSERT(caseValues.size() == caseStatements.size());
175
176    value = context.fTypes.fInt->coerceExpression(std::move(value), context);
177    if (!value) {
178        return nullptr;
179    }
180
181    StatementArray cases;
182    for (int i = 0; i < caseValues.count(); ++i) {
183        int caseLine;
184        std::unique_ptr<Expression> caseValue;
185        if (caseValues[i]) {
186            caseLine = caseValues[i]->fLine;
187
188            // Case values must be the same type as the switch value--`int` or a particular enum.
189            caseValue = value->type().coerceExpression(std::move(caseValues[i]), context);
190            if (!caseValue) {
191                return nullptr;
192            }
193            // Case values must be a literal integer or a `const int` variable reference.
194            SKSL_INT intValue;
195            if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
196                context.fErrors->error(caseValue->fLine, "case value must be a constant integer");
197                return nullptr;
198            }
199        } else {
200            // The null case-expression corresponds to `default:`.
201            caseLine = line;
202        }
203        cases.push_back(std::make_unique<SwitchCase>(caseLine, std::move(caseValue),
204                                                     std::move(caseStatements[i])));
205    }
206
207    // Detect duplicate `case` labels and report an error.
208    // (Using forward_list here to optimize for the common case of no results.)
209    std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
210    if (!duplicateCases.empty()) {
211        duplicateCases.reverse();
212        for (const SwitchCase* sc : duplicateCases) {
213            if (sc->value() != nullptr) {
214                context.fErrors->error(sc->fLine,
215                                       "duplicate case value '" + sc->value()->description() + "'");
216            } else {
217                context.fErrors->error(sc->fLine, "duplicate default case");
218            }
219        }
220        return nullptr;
221    }
222
223    return SwitchStatement::Make(context, line, isStatic, std::move(value), std::move(cases),
224                                 std::move(symbolTable));
225}
226
227std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
228                                                 int line,
229                                                 bool isStatic,
230                                                 std::unique_ptr<Expression> value,
231                                                 StatementArray cases,
232                                                 std::shared_ptr<SymbolTable> symbolTable) {
233    // Confirm that every statement in `cases` is a SwitchCase.
234    SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
235        return stmt->is<SwitchCase>();
236    }));
237
238    // Confirm that every switch-case has been coerced to the proper type.
239    SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
240        return !stmt->as<SwitchCase>().value() ||  // `default` case has a null value
241               value->type() == stmt->as<SwitchCase>().value()->type();
242    }));
243
244    // Confirm that every switch-case value is unique.
245    SkASSERT(find_duplicate_case_values(cases).empty());
246
247    // Flatten @switch statements.
248    if (isStatic || context.fConfig->fSettings.fOptimize) {
249        SKSL_INT switchValue;
250        if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
251            SwitchCase* defaultCase = nullptr;
252            SwitchCase* matchingCase = nullptr;
253            for (const std::unique_ptr<Statement>& stmt : cases) {
254                SwitchCase& sc = stmt->as<SwitchCase>();
255                if (!sc.value()) {
256                    defaultCase = &sc;
257                    continue;
258                }
259
260                SKSL_INT caseValue;
261                SkAssertResult(ConstantFolder::GetConstantInt(*sc.value(), &caseValue));
262                if (caseValue == switchValue) {
263                    matchingCase = &sc;
264                    break;
265                }
266            }
267
268            if (!matchingCase) {
269                // No case value matches the switch value.
270                if (!defaultCase) {
271                    // No default switch-case exists; the switch had no effect.
272                    // We can eliminate the entire switch!
273                    return Nop::Make();
274                }
275                // We had a default case; that's what we matched with.
276                matchingCase = defaultCase;
277            }
278
279            // Convert the switch-case that we matched with into a block.
280            std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
281            if (newBlock) {
282                return newBlock;
283            }
284
285            // Report an error if this was a static switch and BlockForCase failed us.
286            if (isStatic && !context.fConfig->fSettings.fPermitInvalidStaticTests) {
287                context.fErrors->error(value->fLine,
288                                       "static switch contains non-static conditional exit");
289                return nullptr;
290            }
291        }
292    }
293
294    // The switch couldn't be optimized away; emit it normally.
295    return std::make_unique<SwitchStatement>(line, isStatic, std::move(value), std::move(cases),
296                                             std::move(symbolTable));
297}
298
299}  // namespace SkSL
300