1/*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "include/sksl/SkSLErrorReporter.h"
9#include "src/sksl/SkSLAnalysis.h"
10#include "src/sksl/SkSLConstantFolder.h"
11#include "src/sksl/SkSLProgramSettings.h"
12#include "src/sksl/ir/SkSLBinaryExpression.h"
13#include "src/sksl/ir/SkSLIndexExpression.h"
14#include "src/sksl/ir/SkSLLiteral.h"
15#include "src/sksl/ir/SkSLSetting.h"
16#include "src/sksl/ir/SkSLSwizzle.h"
17#include "src/sksl/ir/SkSLTernaryExpression.h"
18#include "src/sksl/ir/SkSLType.h"
19#include "src/sksl/ir/SkSLVariableReference.h"
20
21namespace SkSL {
22
23static bool is_low_precision_matrix_vector_multiply(const Expression& left,
24                                                    const Operator& op,
25                                                    const Expression& right,
26                                                    const Type& resultType) {
27    return !resultType.highPrecision() &&
28           op.kind() == Token::Kind::TK_STAR &&
29           left.type().isMatrix() &&
30           right.type().isVector() &&
31           left.type().rows() == right.type().columns() &&
32           Analysis::IsTrivialExpression(left) &&
33           Analysis::IsTrivialExpression(right);
34}
35
36static std::unique_ptr<Expression> rewrite_matrix_vector_multiply(const Context& context,
37                                                                  const Expression& left,
38                                                                  const Operator& op,
39                                                                  const Expression& right,
40                                                                  const Type& resultType) {
41    // Rewrite m33 * v3 as (m[0] * v[0] + m[1] * v[1] + m[2] * v[2])
42    std::unique_ptr<Expression> sum;
43    for (int n = 0; n < left.type().rows(); ++n) {
44        // Get mat[N] with an index expression.
45        std::unique_ptr<Expression> matN = IndexExpression::Make(
46                context, left.clone(), Literal::MakeInt(context, left.fLine, n));
47        // Get vec[N] with a swizzle expression.
48        std::unique_ptr<Expression> vecN = Swizzle::Make(
49                context, right.clone(), ComponentArray{(SkSL::SwizzleComponent::Type)n});
50        // Multiply them together.
51        const Type* matNType = &matN->type();
52        std::unique_ptr<Expression> product =
53                BinaryExpression::Make(context, std::move(matN), op, std::move(vecN), matNType);
54        // Sum all the components together.
55        if (!sum) {
56            sum = std::move(product);
57        } else {
58            sum = BinaryExpression::Make(context,
59                                         std::move(sum),
60                                         Operator(Token::Kind::TK_PLUS),
61                                         std::move(product),
62                                         matNType);
63        }
64    }
65
66    return sum;
67}
68
69std::unique_ptr<Expression> BinaryExpression::Convert(const Context& context,
70                                                      std::unique_ptr<Expression> left,
71                                                      Operator op,
72                                                      std::unique_ptr<Expression> right) {
73    if (!left || !right) {
74        return nullptr;
75    }
76    const int line = left->fLine;
77
78    const Type* rawLeftType = (left->isIntLiteral() && right->type().isInteger())
79            ? &right->type()
80            : &left->type();
81    const Type* rawRightType = (right->isIntLiteral() && left->type().isInteger())
82            ? &left->type()
83            : &right->type();
84
85    bool isAssignment = op.isAssignment();
86    if (isAssignment &&
87        !Analysis::UpdateVariableRefKind(left.get(),
88                                         op.kind() != Token::Kind::TK_EQ
89                                                 ? VariableReference::RefKind::kReadWrite
90                                                 : VariableReference::RefKind::kWrite,
91                                         context.fErrors)) {
92        return nullptr;
93    }
94
95    const Type* leftType;
96    const Type* rightType;
97    const Type* resultType;
98    if (!op.determineBinaryType(context, *rawLeftType, *rawRightType,
99                                &leftType, &rightType, &resultType)) {
100        context.fErrors->error(line, String("type mismatch: '") + op.operatorName() +
101                                     "' cannot operate on '" + left->type().displayName() +
102                                     "', '" + right->type().displayName() + "'");
103        return nullptr;
104    }
105
106    if (isAssignment && leftType->componentType().isOpaque()) {
107        context.fErrors->error(line, "assignments to opaque type '" + left->type().displayName() +
108                                     "' are not permitted");
109        return nullptr;
110    }
111    if (context.fConfig->strictES2Mode()) {
112        if (!op.isAllowedInStrictES2Mode()) {
113            context.fErrors->error(line, String("operator '") + op.operatorName() +
114                                         "' is not allowed");
115            return nullptr;
116        }
117        if (leftType->isOrContainsArray()) {
118            // Most operators are already rejected on arrays, but GLSL ES 1.0 is very explicit that
119            // the *only* operator allowed on arrays is subscripting (and the rules against
120            // assignment, comparison, and even sequence apply to structs containing arrays as well)
121            context.fErrors->error(line, String("operator '") + op.operatorName() + "' can not "
122                                         "operate on arrays (or structs containing arrays)");
123            return nullptr;
124        }
125    }
126
127    left = leftType->coerceExpression(std::move(left), context);
128    right = rightType->coerceExpression(std::move(right), context);
129    if (!left || !right) {
130        return nullptr;
131    }
132
133    return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
134}
135
136std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
137                                                   std::unique_ptr<Expression> left,
138                                                   Operator op,
139                                                   std::unique_ptr<Expression> right) {
140    // Determine the result type of the binary expression.
141    const Type* leftType;
142    const Type* rightType;
143    const Type* resultType;
144    SkAssertResult(op.determineBinaryType(context, left->type(), right->type(),
145                                          &leftType, &rightType, &resultType));
146
147    return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
148}
149
150std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
151                                                   std::unique_ptr<Expression> left,
152                                                   Operator op,
153                                                   std::unique_ptr<Expression> right,
154                                                   const Type* resultType) {
155    // We should have detected non-ES2 compliant behavior in Convert.
156    SkASSERT(!context.fConfig->strictES2Mode() || op.isAllowedInStrictES2Mode());
157    SkASSERT(!context.fConfig->strictES2Mode() || !left->type().isOrContainsArray());
158
159    // We should have detected non-assignable assignment expressions in Convert.
160    SkASSERT(!op.isAssignment() || Analysis::IsAssignable(*left));
161    SkASSERT(!op.isAssignment() || !left->type().componentType().isOpaque());
162
163    // For simple assignments, detect and report out-of-range literal values.
164    if (op.kind() == Token::Kind::TK_EQ) {
165        left->type().checkForOutOfRangeLiteral(context, *right);
166    }
167
168    // Perform constant-folding on the expression.
169    const int line = left->fLine;
170    if (std::unique_ptr<Expression> result = ConstantFolder::Simplify(context, line, *left,
171                                                                      op, *right, *resultType)) {
172        return result;
173    }
174
175    if (context.fConfig->fSettings.fOptimize) {
176        // When sk_Caps.rewriteMatrixVectorMultiply is set, we rewrite medium-precision
177        // matrix * vector multiplication as:
178        //   (sk_Caps.rewriteMatrixVectorMultiply ? (mat[0]*vec[0] + ... + mat[N]*vec[N])
179        //                                        : mat * vec)
180        if (is_low_precision_matrix_vector_multiply(*left, op, *right, *resultType)) {
181            // Look up `sk_Caps.rewriteMatrixVectorMultiply`.
182            auto caps = Setting::Convert(context, line, "rewriteMatrixVectorMultiply");
183
184            bool capsBitIsTrue = caps->isBoolLiteral() && caps->as<Literal>().boolValue();
185            if (capsBitIsTrue || !caps->isBoolLiteral()) {
186                // Rewrite the multiplication as a sum of vector-scalar products.
187                std::unique_ptr<Expression> rewrite =
188                        rewrite_matrix_vector_multiply(context, *left, op, *right, *resultType);
189
190                // If we know the caps bit is true, return the rewritten expression directly.
191                if (capsBitIsTrue) {
192                    return rewrite;
193                }
194
195                // Return a ternary expression:
196                //     sk_Caps.rewriteMatrixVectorMultiply ? (rewrite) : (mat * vec)
197                return TernaryExpression::Make(
198                        context,
199                        std::move(caps),
200                        std::move(rewrite),
201                        std::make_unique<BinaryExpression>(line, std::move(left), op,
202                                                           std::move(right), resultType));
203            }
204        }
205    }
206
207    return std::make_unique<BinaryExpression>(line, std::move(left), op,
208                                              std::move(right), resultType);
209}
210
211bool BinaryExpression::CheckRef(const Expression& expr) {
212    switch (expr.kind()) {
213        case Expression::Kind::kFieldAccess:
214            return CheckRef(*expr.as<FieldAccess>().base());
215
216        case Expression::Kind::kIndex:
217            return CheckRef(*expr.as<IndexExpression>().base());
218
219        case Expression::Kind::kSwizzle:
220            return CheckRef(*expr.as<Swizzle>().base());
221
222        case Expression::Kind::kTernary: {
223            const TernaryExpression& t = expr.as<TernaryExpression>();
224            return CheckRef(*t.ifTrue()) && CheckRef(*t.ifFalse());
225        }
226        case Expression::Kind::kVariableReference: {
227            const VariableReference& ref = expr.as<VariableReference>();
228            return ref.refKind() == VariableRefKind::kWrite ||
229                   ref.refKind() == VariableRefKind::kReadWrite;
230        }
231        default:
232            return false;
233    }
234}
235
236std::unique_ptr<Expression> BinaryExpression::clone() const {
237    return std::make_unique<BinaryExpression>(fLine,
238                                              this->left()->clone(),
239                                              this->getOperator(),
240                                              this->right()->clone(),
241                                              &this->type());
242}
243
244String BinaryExpression::description() const {
245    return "(" + this->left()->description() +
246           " " + this->getOperator().operatorName() +
247           " " + this->right()->description() + ")";
248}
249
250}  // namespace SkSL
251