1/*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "include/private/SkFloatingPoint.h"
9#include "include/sksl/SkSLErrorReporter.h"
10#include "src/sksl/SkSLAnalysis.h"
11#include "src/sksl/SkSLConstantFolder.h"
12#include "src/sksl/ir/SkSLBinaryExpression.h"
13#include "src/sksl/ir/SkSLForStatement.h"
14#include "src/sksl/ir/SkSLPostfixExpression.h"
15#include "src/sksl/ir/SkSLPrefixExpression.h"
16#include "src/sksl/ir/SkSLVarDeclarations.h"
17#include "src/sksl/ir/SkSLVariableReference.h"
18
19#include <cmath>
20#include <memory>
21
22namespace SkSL {
23
// Upper bound on the number of iterations we are willing to unroll. Loops that run for 100000+
// iterations will exceed our program size limit. The iteration-count helpers in this file also
// return this value as a sentinel for loops that never terminate; any count at or above it is
// rejected.
static constexpr int kLoopTerminationLimit = 100'000;
26
27static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
28    if (forwards != (start < end)) {
29        // The loop starts in a completed state (the start has already advanced past the end).
30        return 0;
31    }
32    if ((delta == 0.0) || forwards != (delta > 0.0)) {
33        // The loop does not progress toward a completed state, and will never terminate.
34        return kLoopTerminationLimit;
35    }
36    double iterations = sk_ieee_double_divide(end - start, delta);
37    double count = std::ceil(iterations);
38    if (inclusive && (count == iterations)) {
39        count += 1.0;
40    }
41    if (count > kLoopTerminationLimit || !std::isfinite(count)) {
42        // The loop runs for more iterations than we can safely unroll.
43        return kLoopTerminationLimit;
44    }
45    return (int)count;
46}
47
48static const char* get_es2_loop_unroll_info(const Statement* loopInitializer,
49                                            const Expression* loopTest,
50                                            const Expression* loopNext,
51                                            const Statement* loopStatement,
52                                            LoopUnrollInfo& loopInfo) {
53    //
54    // init_declaration has the form: type_specifier identifier = constant_expression
55    //
56    if (!loopInitializer) {
57        return "missing init declaration";
58    }
59    if (!loopInitializer->is<VarDeclaration>()) {
60        return "invalid init declaration";
61    }
62    const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
63    if (!initDecl.baseType().isNumber()) {
64        return "invalid type for loop index";
65    }
66    if (initDecl.arraySize() != 0) {
67        return "invalid type for loop index";
68    }
69    if (!initDecl.value()) {
70        return "missing loop index initializer";
71    }
72    if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo.fStart)) {
73        return "loop index initializer must be a constant expression";
74    }
75
76    loopInfo.fIndex = &initDecl.var();
77
78    auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
79        return expr->is<VariableReference>() &&
80               expr->as<VariableReference>().variable() == loopInfo.fIndex;
81    };
82
83    //
84    // condition has the form: loop_index relational_operator constant_expression
85    //
86    if (!loopTest) {
87        return "missing condition";
88    }
89    if (!loopTest->is<BinaryExpression>()) {
90        return "invalid condition";
91    }
92    const BinaryExpression& cond = loopTest->as<BinaryExpression>();
93    if (!is_loop_index(cond.left())) {
94        return "expected loop index on left hand side of condition";
95    }
96    // relational_operator is one of: > >= < <= == or !=
97    switch (cond.getOperator().kind()) {
98        case Token::Kind::TK_GT:
99        case Token::Kind::TK_GTEQ:
100        case Token::Kind::TK_LT:
101        case Token::Kind::TK_LTEQ:
102        case Token::Kind::TK_EQEQ:
103        case Token::Kind::TK_NEQ:
104            break;
105        default:
106            return "invalid relational operator";
107    }
108    double loopEnd = 0;
109    if (!ConstantFolder::GetConstantValue(*cond.right(), &loopEnd)) {
110        return "loop index must be compared with a constant expression";
111    }
112
113    //
114    // expression has one of the following forms:
115    //   loop_index++
116    //   loop_index--
117    //   loop_index += constant_expression
118    //   loop_index -= constant_expression
119    // The spec doesn't mention prefix increment and decrement, but there is some consensus that
120    // it's an oversight, so we allow those as well.
121    //
122    if (!loopNext) {
123        return "missing loop expression";
124    }
125    switch (loopNext->kind()) {
126        case Expression::Kind::kBinary: {
127            const BinaryExpression& next = loopNext->as<BinaryExpression>();
128            if (!is_loop_index(next.left())) {
129                return "expected loop index in loop expression";
130            }
131            if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo.fDelta)) {
132                return "loop index must be modified by a constant expression";
133            }
134            switch (next.getOperator().kind()) {
135                case Token::Kind::TK_PLUSEQ:                                      break;
136                case Token::Kind::TK_MINUSEQ: loopInfo.fDelta = -loopInfo.fDelta; break;
137                default:
138                    return "invalid operator in loop expression";
139            }
140        } break;
141        case Expression::Kind::kPrefix: {
142            const PrefixExpression& next = loopNext->as<PrefixExpression>();
143            if (!is_loop_index(next.operand())) {
144                return "expected loop index in loop expression";
145            }
146            switch (next.getOperator().kind()) {
147                case Token::Kind::TK_PLUSPLUS:   loopInfo.fDelta =  1; break;
148                case Token::Kind::TK_MINUSMINUS: loopInfo.fDelta = -1; break;
149                default:
150                    return "invalid operator in loop expression";
151            }
152        } break;
153        case Expression::Kind::kPostfix: {
154            const PostfixExpression& next = loopNext->as<PostfixExpression>();
155            if (!is_loop_index(next.operand())) {
156                return "expected loop index in loop expression";
157            }
158            switch (next.getOperator().kind()) {
159                case Token::Kind::TK_PLUSPLUS:   loopInfo.fDelta =  1; break;
160                case Token::Kind::TK_MINUSMINUS: loopInfo.fDelta = -1; break;
161                default:
162                    return "invalid operator in loop expression";
163            }
164        } break;
165        default:
166            return "invalid loop expression";
167    }
168
169    //
170    // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
171    // argument to a function 'out' or 'inout' parameter.
172    //
173    if (Analysis::StatementWritesToVariable(*loopStatement, initDecl.var())) {
174        return "loop index must not be modified within body of the loop";
175    }
176
177    // Finally, compute the iteration count, based on the bounds, and the termination operator.
178    loopInfo.fCount = 0;
179
180    switch (cond.getOperator().kind()) {
181        case Token::Kind::TK_LT:
182            loopInfo.fCount = calculate_count(loopInfo.fStart, loopEnd, loopInfo.fDelta,
183                                              /*forwards=*/true, /*inclusive=*/false);
184            break;
185
186        case Token::Kind::TK_GT:
187            loopInfo.fCount = calculate_count(loopInfo.fStart, loopEnd, loopInfo.fDelta,
188                                              /*forwards=*/false, /*inclusive=*/false);
189            break;
190
191        case Token::Kind::TK_LTEQ:
192            loopInfo.fCount = calculate_count(loopInfo.fStart, loopEnd, loopInfo.fDelta,
193                                              /*forwards=*/true, /*inclusive=*/true);
194            break;
195
196        case Token::Kind::TK_GTEQ:
197            loopInfo.fCount = calculate_count(loopInfo.fStart, loopEnd, loopInfo.fDelta,
198                                              /*forwards=*/false, /*inclusive=*/true);
199            break;
200
201        case Token::Kind::TK_NEQ: {
202            float iterations = sk_ieee_double_divide(loopEnd - loopInfo.fStart, loopInfo.fDelta);
203            loopInfo.fCount = std::ceil(iterations);
204            if (loopInfo.fCount < 0 || loopInfo.fCount != iterations ||
205                !std::isfinite(iterations)) {
206                // The loop doesn't reach the exact endpoint and so will never terminate.
207                loopInfo.fCount = kLoopTerminationLimit;
208            }
209            break;
210        }
211        case Token::Kind::TK_EQEQ: {
212            if (loopInfo.fStart == loopEnd) {
213                // Start and end begin in the same place, so we can run one iteration...
214                if (loopInfo.fDelta) {
215                    // ... and then they diverge, so the loop terminates.
216                    loopInfo.fCount = 1;
217                } else {
218                    // ... but they never diverge, so the loop runs forever.
219                    loopInfo.fCount = kLoopTerminationLimit;
220                }
221            } else {
222                // Start never equals end, so the loop will not run a single iteration.
223                loopInfo.fCount = 0;
224            }
225            break;
226        }
227        default: SkUNREACHABLE;
228    }
229
230    SkASSERT(loopInfo.fCount >= 0);
231    if (loopInfo.fCount >= kLoopTerminationLimit) {
232        return "loop must guarantee termination in fewer iterations";
233    }
234
235    return nullptr;  // All checks pass
236}
237
238std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(int line,
239                                                            const Statement* loopInitializer,
240                                                            const Expression* loopTest,
241                                                            const Expression* loopNext,
242                                                            const Statement* loopStatement,
243                                                            ErrorReporter* errors) {
244    auto result = std::make_unique<LoopUnrollInfo>();
245    if (const char* msg = get_es2_loop_unroll_info(loopInitializer, loopTest, loopNext,
246                                                   loopStatement, *result)) {
247        result = nullptr;
248        if (errors) {
249            errors->error(line, msg);
250        }
251    }
252    return result;
253}
254
255}  // namespace SkSL
256