1cb93a386Sopenharmony_ci/*
2cb93a386Sopenharmony_ci * Copyright 2021 Google LLC
3cb93a386Sopenharmony_ci *
4cb93a386Sopenharmony_ci * Use of this source code is governed by a BSD-style license that can be
5cb93a386Sopenharmony_ci * found in the LICENSE file.
6cb93a386Sopenharmony_ci */
7cb93a386Sopenharmony_ci
8cb93a386Sopenharmony_ci#include "src/sksl/SkSLConstantFolder.h"
9cb93a386Sopenharmony_ci#include "src/sksl/SkSLProgramSettings.h"
10cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLBinaryExpression.h"
11cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLConstructorArray.h"
12cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLConstructorCompound.h"
13cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLIndexExpression.h"
14cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLLiteral.h"
15cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLSwizzle.h"
16cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLSymbolTable.h"
17cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLTypeReference.h"
18cb93a386Sopenharmony_ci
19cb93a386Sopenharmony_cinamespace SkSL {
20cb93a386Sopenharmony_ci
21cb93a386Sopenharmony_cistatic bool index_out_of_range(const Context& context, SKSL_INT index, const Expression& base) {
22cb93a386Sopenharmony_ci    if (index >= 0 && index < base.type().columns()) {
23cb93a386Sopenharmony_ci        return false;
24cb93a386Sopenharmony_ci    }
25cb93a386Sopenharmony_ci
26cb93a386Sopenharmony_ci    context.fErrors->error(base.fLine, "index " + to_string(index) + " out of range for '" +
27cb93a386Sopenharmony_ci                                       base.type().displayName() + "'");
28cb93a386Sopenharmony_ci    return true;
29cb93a386Sopenharmony_ci}
30cb93a386Sopenharmony_ci
31cb93a386Sopenharmony_ciconst Type& IndexExpression::IndexType(const Context& context, const Type& type) {
32cb93a386Sopenharmony_ci    if (type.isMatrix()) {
33cb93a386Sopenharmony_ci        if (type.componentType() == *context.fTypes.fFloat) {
34cb93a386Sopenharmony_ci            switch (type.rows()) {
35cb93a386Sopenharmony_ci                case 2: return *context.fTypes.fFloat2;
36cb93a386Sopenharmony_ci                case 3: return *context.fTypes.fFloat3;
37cb93a386Sopenharmony_ci                case 4: return *context.fTypes.fFloat4;
38cb93a386Sopenharmony_ci                default: SkASSERT(false);
39cb93a386Sopenharmony_ci            }
40cb93a386Sopenharmony_ci        } else if (type.componentType() == *context.fTypes.fHalf) {
41cb93a386Sopenharmony_ci            switch (type.rows()) {
42cb93a386Sopenharmony_ci                case 2: return *context.fTypes.fHalf2;
43cb93a386Sopenharmony_ci                case 3: return *context.fTypes.fHalf3;
44cb93a386Sopenharmony_ci                case 4: return *context.fTypes.fHalf4;
45cb93a386Sopenharmony_ci                default: SkASSERT(false);
46cb93a386Sopenharmony_ci            }
47cb93a386Sopenharmony_ci        }
48cb93a386Sopenharmony_ci    }
49cb93a386Sopenharmony_ci    return type.componentType();
50cb93a386Sopenharmony_ci}
51cb93a386Sopenharmony_ci
52cb93a386Sopenharmony_cistd::unique_ptr<Expression> IndexExpression::Convert(const Context& context,
53cb93a386Sopenharmony_ci                                                     SymbolTable& symbolTable,
54cb93a386Sopenharmony_ci                                                     std::unique_ptr<Expression> base,
55cb93a386Sopenharmony_ci                                                     std::unique_ptr<Expression> index) {
56cb93a386Sopenharmony_ci    // Convert an array type reference: `int[10]`.
57cb93a386Sopenharmony_ci    if (base->is<TypeReference>()) {
58cb93a386Sopenharmony_ci        const Type& baseType = base->as<TypeReference>().value();
59cb93a386Sopenharmony_ci        SKSL_INT arraySize = baseType.convertArraySize(context, std::move(index));
60cb93a386Sopenharmony_ci        if (!arraySize) {
61cb93a386Sopenharmony_ci            return nullptr;
62cb93a386Sopenharmony_ci        }
63cb93a386Sopenharmony_ci        return TypeReference::Convert(context, base->fLine,
64cb93a386Sopenharmony_ci                                      symbolTable.addArrayDimension(&baseType, arraySize));
65cb93a386Sopenharmony_ci    }
66cb93a386Sopenharmony_ci    // Convert an index expression with an expression inside of it: `arr[a * 3]`.
67cb93a386Sopenharmony_ci    const Type& baseType = base->type();
68cb93a386Sopenharmony_ci    if (!baseType.isArray() && !baseType.isMatrix() && !baseType.isVector()) {
69cb93a386Sopenharmony_ci        context.fErrors->error(base->fLine,
70cb93a386Sopenharmony_ci                               "expected array, but found '" + baseType.displayName() + "'");
71cb93a386Sopenharmony_ci        return nullptr;
72cb93a386Sopenharmony_ci    }
73cb93a386Sopenharmony_ci    if (!index->type().isInteger()) {
74cb93a386Sopenharmony_ci        index = context.fTypes.fInt->coerceExpression(std::move(index), context);
75cb93a386Sopenharmony_ci        if (!index) {
76cb93a386Sopenharmony_ci            return nullptr;
77cb93a386Sopenharmony_ci        }
78cb93a386Sopenharmony_ci    }
79cb93a386Sopenharmony_ci    // Perform compile-time bounds checking on constant-expression indices.
80cb93a386Sopenharmony_ci    const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
81cb93a386Sopenharmony_ci    if (indexExpr->isIntLiteral()) {
82cb93a386Sopenharmony_ci        SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
83cb93a386Sopenharmony_ci        if (index_out_of_range(context, indexValue, *base)) {
84cb93a386Sopenharmony_ci            return nullptr;
85cb93a386Sopenharmony_ci        }
86cb93a386Sopenharmony_ci    }
87cb93a386Sopenharmony_ci    return IndexExpression::Make(context, std::move(base), std::move(index));
88cb93a386Sopenharmony_ci}
89cb93a386Sopenharmony_ci
90cb93a386Sopenharmony_cistd::unique_ptr<Expression> IndexExpression::Make(const Context& context,
91cb93a386Sopenharmony_ci                                                  std::unique_ptr<Expression> base,
92cb93a386Sopenharmony_ci                                                  std::unique_ptr<Expression> index) {
93cb93a386Sopenharmony_ci    const Type& baseType = base->type();
94cb93a386Sopenharmony_ci    SkASSERT(baseType.isArray() || baseType.isMatrix() || baseType.isVector());
95cb93a386Sopenharmony_ci    SkASSERT(index->type().isInteger());
96cb93a386Sopenharmony_ci
97cb93a386Sopenharmony_ci    const Expression* indexExpr = ConstantFolder::GetConstantValueForVariable(*index);
98cb93a386Sopenharmony_ci    if (indexExpr->isIntLiteral()) {
99cb93a386Sopenharmony_ci        SKSL_INT indexValue = indexExpr->as<Literal>().intValue();
100cb93a386Sopenharmony_ci        if (!index_out_of_range(context, indexValue, *base)) {
101cb93a386Sopenharmony_ci            if (baseType.isVector()) {
102cb93a386Sopenharmony_ci                // Constant array indexes on vectors can be converted to swizzles: `v[2]` --> `v.z`.
103cb93a386Sopenharmony_ci                // Swizzling is harmless and can unlock further simplifications for some base types.
104cb93a386Sopenharmony_ci                return Swizzle::Make(context, std::move(base), ComponentArray{(int8_t)indexValue});
105cb93a386Sopenharmony_ci            }
106cb93a386Sopenharmony_ci
107cb93a386Sopenharmony_ci            if (baseType.isArray() && !base->hasSideEffects()) {
108cb93a386Sopenharmony_ci                // Indexing an constant array constructor with a constant index can just pluck out
109cb93a386Sopenharmony_ci                // the requested value from the array.
110cb93a386Sopenharmony_ci                const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
111cb93a386Sopenharmony_ci                if (baseExpr->is<ConstructorArray>()) {
112cb93a386Sopenharmony_ci                    const ConstructorArray& arrayCtor = baseExpr->as<ConstructorArray>();
113cb93a386Sopenharmony_ci                    const ExpressionArray& arguments = arrayCtor.arguments();
114cb93a386Sopenharmony_ci                    SkASSERT(arguments.count() == baseType.columns());
115cb93a386Sopenharmony_ci
116cb93a386Sopenharmony_ci                    return arguments[indexValue]->clone();
117cb93a386Sopenharmony_ci                }
118cb93a386Sopenharmony_ci            }
119cb93a386Sopenharmony_ci
120cb93a386Sopenharmony_ci            if (baseType.isMatrix() && !base->hasSideEffects()) {
121cb93a386Sopenharmony_ci                // Matrices can be constructed with vectors that don't line up on column boundaries,
122cb93a386Sopenharmony_ci                // so extracting out the values from the constructor can be tricky. Fortunately, we
123cb93a386Sopenharmony_ci                // can reconstruct an equivalent vector using `getConstantValue`. If we
124cb93a386Sopenharmony_ci                // can't extract the data using `getConstantValue`, it wasn't constant and
125cb93a386Sopenharmony_ci                // we're not obligated to simplify anything.
126cb93a386Sopenharmony_ci                const Expression* baseExpr = ConstantFolder::GetConstantValueForVariable(*base);
127cb93a386Sopenharmony_ci                int vecWidth = baseType.rows();
128cb93a386Sopenharmony_ci                const Type& scalarType = baseType.componentType();
129cb93a386Sopenharmony_ci                const Type& vecType = scalarType.toCompound(context, vecWidth, /*rows=*/1);
130cb93a386Sopenharmony_ci                indexValue *= vecWidth;
131cb93a386Sopenharmony_ci
132cb93a386Sopenharmony_ci                ExpressionArray ctorArgs;
133cb93a386Sopenharmony_ci                ctorArgs.reserve_back(vecWidth);
134cb93a386Sopenharmony_ci                for (int slot = 0; slot < vecWidth; ++slot) {
135cb93a386Sopenharmony_ci                    skstd::optional<double> slotVal = baseExpr->getConstantValue(indexValue + slot);
136cb93a386Sopenharmony_ci                    if (slotVal.has_value()) {
137cb93a386Sopenharmony_ci                        ctorArgs.push_back(Literal::Make(baseExpr->fLine, *slotVal, &scalarType));
138cb93a386Sopenharmony_ci                    } else {
139cb93a386Sopenharmony_ci                        ctorArgs.reset();
140cb93a386Sopenharmony_ci                        break;
141cb93a386Sopenharmony_ci                    }
142cb93a386Sopenharmony_ci                }
143cb93a386Sopenharmony_ci
144cb93a386Sopenharmony_ci                if (!ctorArgs.empty()) {
145cb93a386Sopenharmony_ci                    int line = ctorArgs.front()->fLine;
146cb93a386Sopenharmony_ci                    return ConstructorCompound::Make(context, line, vecType, std::move(ctorArgs));
147cb93a386Sopenharmony_ci                }
148cb93a386Sopenharmony_ci            }
149cb93a386Sopenharmony_ci        }
150cb93a386Sopenharmony_ci    }
151cb93a386Sopenharmony_ci
152cb93a386Sopenharmony_ci    return std::make_unique<IndexExpression>(context, std::move(base), std::move(index));
153cb93a386Sopenharmony_ci}
154cb93a386Sopenharmony_ci
155cb93a386Sopenharmony_ci}  // namespace SkSL
156