1cb93a386Sopenharmony_ci/*
2cb93a386Sopenharmony_ci * Copyright 2021 Google LLC
3cb93a386Sopenharmony_ci *
4cb93a386Sopenharmony_ci * Use of this source code is governed by a BSD-style license that can be
5cb93a386Sopenharmony_ci * found in the LICENSE file.
6cb93a386Sopenharmony_ci */
7cb93a386Sopenharmony_ci
8cb93a386Sopenharmony_ci#include "include/sksl/SkSLErrorReporter.h"
9cb93a386Sopenharmony_ci#include "src/sksl/SkSLAnalysis.h"
10cb93a386Sopenharmony_ci#include "src/sksl/SkSLConstantFolder.h"
11cb93a386Sopenharmony_ci#include "src/sksl/SkSLProgramSettings.h"
12cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLBinaryExpression.h"
13cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLIndexExpression.h"
14cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLLiteral.h"
15cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLSetting.h"
16cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLSwizzle.h"
17cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLTernaryExpression.h"
18cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLType.h"
19cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLVariableReference.h"
20cb93a386Sopenharmony_ci
21cb93a386Sopenharmony_cinamespace SkSL {
22cb93a386Sopenharmony_ci
23cb93a386Sopenharmony_cistatic bool is_low_precision_matrix_vector_multiply(const Expression& left,
24cb93a386Sopenharmony_ci                                                    const Operator& op,
25cb93a386Sopenharmony_ci                                                    const Expression& right,
26cb93a386Sopenharmony_ci                                                    const Type& resultType) {
27cb93a386Sopenharmony_ci    return !resultType.highPrecision() &&
28cb93a386Sopenharmony_ci           op.kind() == Token::Kind::TK_STAR &&
29cb93a386Sopenharmony_ci           left.type().isMatrix() &&
30cb93a386Sopenharmony_ci           right.type().isVector() &&
31cb93a386Sopenharmony_ci           left.type().rows() == right.type().columns() &&
32cb93a386Sopenharmony_ci           Analysis::IsTrivialExpression(left) &&
33cb93a386Sopenharmony_ci           Analysis::IsTrivialExpression(right);
34cb93a386Sopenharmony_ci}
35cb93a386Sopenharmony_ci
36cb93a386Sopenharmony_cistatic std::unique_ptr<Expression> rewrite_matrix_vector_multiply(const Context& context,
37cb93a386Sopenharmony_ci                                                                  const Expression& left,
38cb93a386Sopenharmony_ci                                                                  const Operator& op,
39cb93a386Sopenharmony_ci                                                                  const Expression& right,
40cb93a386Sopenharmony_ci                                                                  const Type& resultType) {
41cb93a386Sopenharmony_ci    // Rewrite m33 * v3 as (m[0] * v[0] + m[1] * v[1] + m[2] * v[2])
42cb93a386Sopenharmony_ci    std::unique_ptr<Expression> sum;
43cb93a386Sopenharmony_ci    for (int n = 0; n < left.type().rows(); ++n) {
44cb93a386Sopenharmony_ci        // Get mat[N] with an index expression.
45cb93a386Sopenharmony_ci        std::unique_ptr<Expression> matN = IndexExpression::Make(
46cb93a386Sopenharmony_ci                context, left.clone(), Literal::MakeInt(context, left.fLine, n));
47cb93a386Sopenharmony_ci        // Get vec[N] with a swizzle expression.
48cb93a386Sopenharmony_ci        std::unique_ptr<Expression> vecN = Swizzle::Make(
49cb93a386Sopenharmony_ci                context, right.clone(), ComponentArray{(SkSL::SwizzleComponent::Type)n});
50cb93a386Sopenharmony_ci        // Multiply them together.
51cb93a386Sopenharmony_ci        const Type* matNType = &matN->type();
52cb93a386Sopenharmony_ci        std::unique_ptr<Expression> product =
53cb93a386Sopenharmony_ci                BinaryExpression::Make(context, std::move(matN), op, std::move(vecN), matNType);
54cb93a386Sopenharmony_ci        // Sum all the components together.
55cb93a386Sopenharmony_ci        if (!sum) {
56cb93a386Sopenharmony_ci            sum = std::move(product);
57cb93a386Sopenharmony_ci        } else {
58cb93a386Sopenharmony_ci            sum = BinaryExpression::Make(context,
59cb93a386Sopenharmony_ci                                         std::move(sum),
60cb93a386Sopenharmony_ci                                         Operator(Token::Kind::TK_PLUS),
61cb93a386Sopenharmony_ci                                         std::move(product),
62cb93a386Sopenharmony_ci                                         matNType);
63cb93a386Sopenharmony_ci        }
64cb93a386Sopenharmony_ci    }
65cb93a386Sopenharmony_ci
66cb93a386Sopenharmony_ci    return sum;
67cb93a386Sopenharmony_ci}
68cb93a386Sopenharmony_ci
69cb93a386Sopenharmony_cistd::unique_ptr<Expression> BinaryExpression::Convert(const Context& context,
70cb93a386Sopenharmony_ci                                                      std::unique_ptr<Expression> left,
71cb93a386Sopenharmony_ci                                                      Operator op,
72cb93a386Sopenharmony_ci                                                      std::unique_ptr<Expression> right) {
73cb93a386Sopenharmony_ci    if (!left || !right) {
74cb93a386Sopenharmony_ci        return nullptr;
75cb93a386Sopenharmony_ci    }
76cb93a386Sopenharmony_ci    const int line = left->fLine;
77cb93a386Sopenharmony_ci
78cb93a386Sopenharmony_ci    const Type* rawLeftType = (left->isIntLiteral() && right->type().isInteger())
79cb93a386Sopenharmony_ci            ? &right->type()
80cb93a386Sopenharmony_ci            : &left->type();
81cb93a386Sopenharmony_ci    const Type* rawRightType = (right->isIntLiteral() && left->type().isInteger())
82cb93a386Sopenharmony_ci            ? &left->type()
83cb93a386Sopenharmony_ci            : &right->type();
84cb93a386Sopenharmony_ci
85cb93a386Sopenharmony_ci    bool isAssignment = op.isAssignment();
86cb93a386Sopenharmony_ci    if (isAssignment &&
87cb93a386Sopenharmony_ci        !Analysis::UpdateVariableRefKind(left.get(),
88cb93a386Sopenharmony_ci                                         op.kind() != Token::Kind::TK_EQ
89cb93a386Sopenharmony_ci                                                 ? VariableReference::RefKind::kReadWrite
90cb93a386Sopenharmony_ci                                                 : VariableReference::RefKind::kWrite,
91cb93a386Sopenharmony_ci                                         context.fErrors)) {
92cb93a386Sopenharmony_ci        return nullptr;
93cb93a386Sopenharmony_ci    }
94cb93a386Sopenharmony_ci
95cb93a386Sopenharmony_ci    const Type* leftType;
96cb93a386Sopenharmony_ci    const Type* rightType;
97cb93a386Sopenharmony_ci    const Type* resultType;
98cb93a386Sopenharmony_ci    if (!op.determineBinaryType(context, *rawLeftType, *rawRightType,
99cb93a386Sopenharmony_ci                                &leftType, &rightType, &resultType)) {
100cb93a386Sopenharmony_ci        context.fErrors->error(line, String("type mismatch: '") + op.operatorName() +
101cb93a386Sopenharmony_ci                                     "' cannot operate on '" + left->type().displayName() +
102cb93a386Sopenharmony_ci                                     "', '" + right->type().displayName() + "'");
103cb93a386Sopenharmony_ci        return nullptr;
104cb93a386Sopenharmony_ci    }
105cb93a386Sopenharmony_ci
106cb93a386Sopenharmony_ci    if (isAssignment && leftType->componentType().isOpaque()) {
107cb93a386Sopenharmony_ci        context.fErrors->error(line, "assignments to opaque type '" + left->type().displayName() +
108cb93a386Sopenharmony_ci                                     "' are not permitted");
109cb93a386Sopenharmony_ci        return nullptr;
110cb93a386Sopenharmony_ci    }
111cb93a386Sopenharmony_ci    if (context.fConfig->strictES2Mode()) {
112cb93a386Sopenharmony_ci        if (!op.isAllowedInStrictES2Mode()) {
113cb93a386Sopenharmony_ci            context.fErrors->error(line, String("operator '") + op.operatorName() +
114cb93a386Sopenharmony_ci                                         "' is not allowed");
115cb93a386Sopenharmony_ci            return nullptr;
116cb93a386Sopenharmony_ci        }
117cb93a386Sopenharmony_ci        if (leftType->isOrContainsArray()) {
118cb93a386Sopenharmony_ci            // Most operators are already rejected on arrays, but GLSL ES 1.0 is very explicit that
119cb93a386Sopenharmony_ci            // the *only* operator allowed on arrays is subscripting (and the rules against
120cb93a386Sopenharmony_ci            // assignment, comparison, and even sequence apply to structs containing arrays as well)
121cb93a386Sopenharmony_ci            context.fErrors->error(line, String("operator '") + op.operatorName() + "' can not "
122cb93a386Sopenharmony_ci                                         "operate on arrays (or structs containing arrays)");
123cb93a386Sopenharmony_ci            return nullptr;
124cb93a386Sopenharmony_ci        }
125cb93a386Sopenharmony_ci    }
126cb93a386Sopenharmony_ci
127cb93a386Sopenharmony_ci    left = leftType->coerceExpression(std::move(left), context);
128cb93a386Sopenharmony_ci    right = rightType->coerceExpression(std::move(right), context);
129cb93a386Sopenharmony_ci    if (!left || !right) {
130cb93a386Sopenharmony_ci        return nullptr;
131cb93a386Sopenharmony_ci    }
132cb93a386Sopenharmony_ci
133cb93a386Sopenharmony_ci    return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
134cb93a386Sopenharmony_ci}
135cb93a386Sopenharmony_ci
136cb93a386Sopenharmony_cistd::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
137cb93a386Sopenharmony_ci                                                   std::unique_ptr<Expression> left,
138cb93a386Sopenharmony_ci                                                   Operator op,
139cb93a386Sopenharmony_ci                                                   std::unique_ptr<Expression> right) {
140cb93a386Sopenharmony_ci    // Determine the result type of the binary expression.
141cb93a386Sopenharmony_ci    const Type* leftType;
142cb93a386Sopenharmony_ci    const Type* rightType;
143cb93a386Sopenharmony_ci    const Type* resultType;
144cb93a386Sopenharmony_ci    SkAssertResult(op.determineBinaryType(context, left->type(), right->type(),
145cb93a386Sopenharmony_ci                                          &leftType, &rightType, &resultType));
146cb93a386Sopenharmony_ci
147cb93a386Sopenharmony_ci    return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
148cb93a386Sopenharmony_ci}
149cb93a386Sopenharmony_ci
150cb93a386Sopenharmony_cistd::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
151cb93a386Sopenharmony_ci                                                   std::unique_ptr<Expression> left,
152cb93a386Sopenharmony_ci                                                   Operator op,
153cb93a386Sopenharmony_ci                                                   std::unique_ptr<Expression> right,
154cb93a386Sopenharmony_ci                                                   const Type* resultType) {
155cb93a386Sopenharmony_ci    // We should have detected non-ES2 compliant behavior in Convert.
156cb93a386Sopenharmony_ci    SkASSERT(!context.fConfig->strictES2Mode() || op.isAllowedInStrictES2Mode());
157cb93a386Sopenharmony_ci    SkASSERT(!context.fConfig->strictES2Mode() || !left->type().isOrContainsArray());
158cb93a386Sopenharmony_ci
159cb93a386Sopenharmony_ci    // We should have detected non-assignable assignment expressions in Convert.
160cb93a386Sopenharmony_ci    SkASSERT(!op.isAssignment() || Analysis::IsAssignable(*left));
161cb93a386Sopenharmony_ci    SkASSERT(!op.isAssignment() || !left->type().componentType().isOpaque());
162cb93a386Sopenharmony_ci
163cb93a386Sopenharmony_ci    // For simple assignments, detect and report out-of-range literal values.
164cb93a386Sopenharmony_ci    if (op.kind() == Token::Kind::TK_EQ) {
165cb93a386Sopenharmony_ci        left->type().checkForOutOfRangeLiteral(context, *right);
166cb93a386Sopenharmony_ci    }
167cb93a386Sopenharmony_ci
168cb93a386Sopenharmony_ci    // Perform constant-folding on the expression.
169cb93a386Sopenharmony_ci    const int line = left->fLine;
170cb93a386Sopenharmony_ci    if (std::unique_ptr<Expression> result = ConstantFolder::Simplify(context, line, *left,
171cb93a386Sopenharmony_ci                                                                      op, *right, *resultType)) {
172cb93a386Sopenharmony_ci        return result;
173cb93a386Sopenharmony_ci    }
174cb93a386Sopenharmony_ci
175cb93a386Sopenharmony_ci    if (context.fConfig->fSettings.fOptimize) {
176cb93a386Sopenharmony_ci        // When sk_Caps.rewriteMatrixVectorMultiply is set, we rewrite medium-precision
177cb93a386Sopenharmony_ci        // matrix * vector multiplication as:
178cb93a386Sopenharmony_ci        //   (sk_Caps.rewriteMatrixVectorMultiply ? (mat[0]*vec[0] + ... + mat[N]*vec[N])
179cb93a386Sopenharmony_ci        //                                        : mat * vec)
180cb93a386Sopenharmony_ci        if (is_low_precision_matrix_vector_multiply(*left, op, *right, *resultType)) {
181cb93a386Sopenharmony_ci            // Look up `sk_Caps.rewriteMatrixVectorMultiply`.
182cb93a386Sopenharmony_ci            auto caps = Setting::Convert(context, line, "rewriteMatrixVectorMultiply");
183cb93a386Sopenharmony_ci
184cb93a386Sopenharmony_ci            bool capsBitIsTrue = caps->isBoolLiteral() && caps->as<Literal>().boolValue();
185cb93a386Sopenharmony_ci            if (capsBitIsTrue || !caps->isBoolLiteral()) {
186cb93a386Sopenharmony_ci                // Rewrite the multiplication as a sum of vector-scalar products.
187cb93a386Sopenharmony_ci                std::unique_ptr<Expression> rewrite =
188cb93a386Sopenharmony_ci                        rewrite_matrix_vector_multiply(context, *left, op, *right, *resultType);
189cb93a386Sopenharmony_ci
190cb93a386Sopenharmony_ci                // If we know the caps bit is true, return the rewritten expression directly.
191cb93a386Sopenharmony_ci                if (capsBitIsTrue) {
192cb93a386Sopenharmony_ci                    return rewrite;
193cb93a386Sopenharmony_ci                }
194cb93a386Sopenharmony_ci
195cb93a386Sopenharmony_ci                // Return a ternary expression:
196cb93a386Sopenharmony_ci                //     sk_Caps.rewriteMatrixVectorMultiply ? (rewrite) : (mat * vec)
197cb93a386Sopenharmony_ci                return TernaryExpression::Make(
198cb93a386Sopenharmony_ci                        context,
199cb93a386Sopenharmony_ci                        std::move(caps),
200cb93a386Sopenharmony_ci                        std::move(rewrite),
201cb93a386Sopenharmony_ci                        std::make_unique<BinaryExpression>(line, std::move(left), op,
202cb93a386Sopenharmony_ci                                                           std::move(right), resultType));
203cb93a386Sopenharmony_ci            }
204cb93a386Sopenharmony_ci        }
205cb93a386Sopenharmony_ci    }
206cb93a386Sopenharmony_ci
207cb93a386Sopenharmony_ci    return std::make_unique<BinaryExpression>(line, std::move(left), op,
208cb93a386Sopenharmony_ci                                              std::move(right), resultType);
209cb93a386Sopenharmony_ci}
210cb93a386Sopenharmony_ci
211cb93a386Sopenharmony_cibool BinaryExpression::CheckRef(const Expression& expr) {
212cb93a386Sopenharmony_ci    switch (expr.kind()) {
213cb93a386Sopenharmony_ci        case Expression::Kind::kFieldAccess:
214cb93a386Sopenharmony_ci            return CheckRef(*expr.as<FieldAccess>().base());
215cb93a386Sopenharmony_ci
216cb93a386Sopenharmony_ci        case Expression::Kind::kIndex:
217cb93a386Sopenharmony_ci            return CheckRef(*expr.as<IndexExpression>().base());
218cb93a386Sopenharmony_ci
219cb93a386Sopenharmony_ci        case Expression::Kind::kSwizzle:
220cb93a386Sopenharmony_ci            return CheckRef(*expr.as<Swizzle>().base());
221cb93a386Sopenharmony_ci
222cb93a386Sopenharmony_ci        case Expression::Kind::kTernary: {
223cb93a386Sopenharmony_ci            const TernaryExpression& t = expr.as<TernaryExpression>();
224cb93a386Sopenharmony_ci            return CheckRef(*t.ifTrue()) && CheckRef(*t.ifFalse());
225cb93a386Sopenharmony_ci        }
226cb93a386Sopenharmony_ci        case Expression::Kind::kVariableReference: {
227cb93a386Sopenharmony_ci            const VariableReference& ref = expr.as<VariableReference>();
228cb93a386Sopenharmony_ci            return ref.refKind() == VariableRefKind::kWrite ||
229cb93a386Sopenharmony_ci                   ref.refKind() == VariableRefKind::kReadWrite;
230cb93a386Sopenharmony_ci        }
231cb93a386Sopenharmony_ci        default:
232cb93a386Sopenharmony_ci            return false;
233cb93a386Sopenharmony_ci    }
234cb93a386Sopenharmony_ci}
235cb93a386Sopenharmony_ci
236cb93a386Sopenharmony_cistd::unique_ptr<Expression> BinaryExpression::clone() const {
237cb93a386Sopenharmony_ci    return std::make_unique<BinaryExpression>(fLine,
238cb93a386Sopenharmony_ci                                              this->left()->clone(),
239cb93a386Sopenharmony_ci                                              this->getOperator(),
240cb93a386Sopenharmony_ci                                              this->right()->clone(),
241cb93a386Sopenharmony_ci                                              &this->type());
242cb93a386Sopenharmony_ci}
243cb93a386Sopenharmony_ci
244cb93a386Sopenharmony_ciString BinaryExpression::description() const {
245cb93a386Sopenharmony_ci    return "(" + this->left()->description() +
246cb93a386Sopenharmony_ci           " " + this->getOperator().operatorName() +
247cb93a386Sopenharmony_ci           " " + this->right()->description() + ")";
248cb93a386Sopenharmony_ci}
249cb93a386Sopenharmony_ci
250cb93a386Sopenharmony_ci}  // namespace SkSL
251