1cb93a386Sopenharmony_ci/*
2cb93a386Sopenharmony_ci * Copyright 2021 Google LLC
3cb93a386Sopenharmony_ci *
4cb93a386Sopenharmony_ci * Use of this source code is governed by a BSD-style license that can be
5cb93a386Sopenharmony_ci * found in the LICENSE file.
6cb93a386Sopenharmony_ci */
7cb93a386Sopenharmony_ci
8cb93a386Sopenharmony_ci#include "include/sksl/DSLCore.h"
9cb93a386Sopenharmony_ci#include "src/core/SkSafeMath.h"
10cb93a386Sopenharmony_ci#include "src/sksl/SkSLAnalysis.h"
11cb93a386Sopenharmony_ci#include "src/sksl/SkSLCompiler.h"
12cb93a386Sopenharmony_ci#include "src/sksl/SkSLContext.h"
13cb93a386Sopenharmony_ci#include "src/sksl/SkSLIntrinsicMap.h"
14cb93a386Sopenharmony_ci#include "src/sksl/SkSLProgramSettings.h"
15cb93a386Sopenharmony_ci#include "src/sksl/SkSLThreadContext.h"
16cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLFieldAccess.h"
17cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLFunctionCall.h"
18cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLFunctionDefinition.h"
19cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLInterfaceBlock.h"
20cb93a386Sopenharmony_ci#include "src/sksl/ir/SkSLReturnStatement.h"
21cb93a386Sopenharmony_ci#include "src/sksl/transform/SkSLProgramWriter.h"
22cb93a386Sopenharmony_ci
23cb93a386Sopenharmony_ci#include <forward_list>
24cb93a386Sopenharmony_ci
25cb93a386Sopenharmony_cinamespace SkSL {
26cb93a386Sopenharmony_ci
27cb93a386Sopenharmony_cistatic void append_rtadjust_fixup_to_vertex_main(const Context& context,
28cb93a386Sopenharmony_ci        const FunctionDeclaration& decl, Block& body) {
29cb93a386Sopenharmony_ci    using namespace SkSL::dsl;
30cb93a386Sopenharmony_ci    using SkSL::dsl::Swizzle;  // disambiguate from SkSL::Swizzle
31cb93a386Sopenharmony_ci    using OwnerKind = SkSL::FieldAccess::OwnerKind;
32cb93a386Sopenharmony_ci
33cb93a386Sopenharmony_ci    // If this program uses RTAdjust...
34cb93a386Sopenharmony_ci    ThreadContext::RTAdjustData& rtAdjust = ThreadContext::RTAdjustState();
35cb93a386Sopenharmony_ci    if (rtAdjust.fVar || rtAdjust.fInterfaceBlock) {
36cb93a386Sopenharmony_ci        // ...append a line to the end of the function body which fixes up sk_Position.
37cb93a386Sopenharmony_ci        const Variable* skPerVertex = nullptr;
38cb93a386Sopenharmony_ci        if (const ProgramElement* perVertexDecl =
39cb93a386Sopenharmony_ci                context.fIntrinsics->find(Compiler::PERVERTEX_NAME)) {
40cb93a386Sopenharmony_ci            SkASSERT(perVertexDecl->is<SkSL::InterfaceBlock>());
41cb93a386Sopenharmony_ci            skPerVertex = &perVertexDecl->as<SkSL::InterfaceBlock>().variable();
42cb93a386Sopenharmony_ci        }
43cb93a386Sopenharmony_ci
44cb93a386Sopenharmony_ci        SkASSERT(skPerVertex);
45cb93a386Sopenharmony_ci        auto Ref = [](const Variable* var) -> std::unique_ptr<Expression> {
46cb93a386Sopenharmony_ci            return VariableReference::Make(/*line=*/-1, var);
47cb93a386Sopenharmony_ci        };
48cb93a386Sopenharmony_ci        auto Field = [&](const Variable* var, int idx) -> std::unique_ptr<Expression> {
49cb93a386Sopenharmony_ci            return FieldAccess::Make(context, Ref(var), idx, OwnerKind::kAnonymousInterfaceBlock);
50cb93a386Sopenharmony_ci        };
51cb93a386Sopenharmony_ci        auto Pos = [&]() -> DSLExpression {
52cb93a386Sopenharmony_ci            return DSLExpression(FieldAccess::Make(context, Ref(skPerVertex), /*fieldIndex=*/0,
53cb93a386Sopenharmony_ci                                                   OwnerKind::kAnonymousInterfaceBlock));
54cb93a386Sopenharmony_ci        };
55cb93a386Sopenharmony_ci        auto Adjust = [&]() -> DSLExpression {
56cb93a386Sopenharmony_ci            return DSLExpression(rtAdjust.fInterfaceBlock
57cb93a386Sopenharmony_ci                                         ? Field(rtAdjust.fInterfaceBlock, rtAdjust.fFieldIndex)
58cb93a386Sopenharmony_ci                                         : Ref(rtAdjust.fVar));
59cb93a386Sopenharmony_ci        };
60cb93a386Sopenharmony_ci
61cb93a386Sopenharmony_ci        auto fixupStmt = DSLStatement(
62cb93a386Sopenharmony_ci            Pos() = Float4(Swizzle(Pos(), X, Y) * Swizzle(Adjust(), X, Z) +
63cb93a386Sopenharmony_ci                           Swizzle(Pos(), W, W) * Swizzle(Adjust(), Y, W),
64cb93a386Sopenharmony_ci                           0,
65cb93a386Sopenharmony_ci                           Pos().w())
66cb93a386Sopenharmony_ci        );
67cb93a386Sopenharmony_ci
68cb93a386Sopenharmony_ci        body.children().push_back(fixupStmt.release());
69cb93a386Sopenharmony_ci    }
70cb93a386Sopenharmony_ci}
71cb93a386Sopenharmony_ci
72cb93a386Sopenharmony_cistd::unique_ptr<FunctionDefinition> FunctionDefinition::Convert(const Context& context,
73cb93a386Sopenharmony_ci                                                                int line,
74cb93a386Sopenharmony_ci                                                                const FunctionDeclaration& function,
75cb93a386Sopenharmony_ci                                                                std::unique_ptr<Statement> body,
76cb93a386Sopenharmony_ci                                                                bool builtin) {
77cb93a386Sopenharmony_ci    class Finalizer : public ProgramWriter {
78cb93a386Sopenharmony_ci    public:
79cb93a386Sopenharmony_ci        Finalizer(const Context& context, const FunctionDeclaration& function,
80cb93a386Sopenharmony_ci                  IntrinsicSet* referencedIntrinsics)
81cb93a386Sopenharmony_ci            : fContext(context)
82cb93a386Sopenharmony_ci            , fFunction(function)
83cb93a386Sopenharmony_ci            , fReferencedIntrinsics(referencedIntrinsics) {}
84cb93a386Sopenharmony_ci
85cb93a386Sopenharmony_ci        ~Finalizer() override {
86cb93a386Sopenharmony_ci            SkASSERT(fBreakableLevel == 0);
87cb93a386Sopenharmony_ci            SkASSERT(fContinuableLevel == std::forward_list<int>{0});
88cb93a386Sopenharmony_ci        }
89cb93a386Sopenharmony_ci
90cb93a386Sopenharmony_ci        void copyIntrinsicIfNeeded(const FunctionDeclaration& function) {
91cb93a386Sopenharmony_ci            if (const ProgramElement* found =
92cb93a386Sopenharmony_ci                    fContext.fIntrinsics->findAndInclude(function.description())) {
93cb93a386Sopenharmony_ci                const FunctionDefinition& original = found->as<FunctionDefinition>();
94cb93a386Sopenharmony_ci
95cb93a386Sopenharmony_ci                // Sort the referenced intrinsics into a consistent order; otherwise our output will
96cb93a386Sopenharmony_ci                // become non-deterministic.
97cb93a386Sopenharmony_ci                std::vector<const FunctionDeclaration*> intrinsics(
98cb93a386Sopenharmony_ci                        original.referencedIntrinsics().begin(),
99cb93a386Sopenharmony_ci                        original.referencedIntrinsics().end());
100cb93a386Sopenharmony_ci                std::sort(intrinsics.begin(), intrinsics.end(),
101cb93a386Sopenharmony_ci                          [](const FunctionDeclaration* a, const FunctionDeclaration* b) {
102cb93a386Sopenharmony_ci                              if (a->isBuiltin() != b->isBuiltin()) {
103cb93a386Sopenharmony_ci                                  return a->isBuiltin() < b->isBuiltin();
104cb93a386Sopenharmony_ci                              }
105cb93a386Sopenharmony_ci                              if (a->fLine != b->fLine) {
106cb93a386Sopenharmony_ci                                  return a->fLine < b->fLine;
107cb93a386Sopenharmony_ci                              }
108cb93a386Sopenharmony_ci                              if (a->name() != b->name()) {
109cb93a386Sopenharmony_ci                                  return a->name() < b->name();
110cb93a386Sopenharmony_ci                              }
111cb93a386Sopenharmony_ci                              return a->description() < b->description();
112cb93a386Sopenharmony_ci                          });
113cb93a386Sopenharmony_ci                for (const FunctionDeclaration* f : intrinsics) {
114cb93a386Sopenharmony_ci                    this->copyIntrinsicIfNeeded(*f);
115cb93a386Sopenharmony_ci                }
116cb93a386Sopenharmony_ci
117cb93a386Sopenharmony_ci                ThreadContext::SharedElements().push_back(found);
118cb93a386Sopenharmony_ci            }
119cb93a386Sopenharmony_ci        }
120cb93a386Sopenharmony_ci
121cb93a386Sopenharmony_ci        bool functionReturnsValue() const {
122cb93a386Sopenharmony_ci            return !fFunction.returnType().isVoid();
123cb93a386Sopenharmony_ci        }
124cb93a386Sopenharmony_ci
125cb93a386Sopenharmony_ci        bool visitExpression(Expression& expr) override {
126cb93a386Sopenharmony_ci            if (expr.is<FunctionCall>()) {
127cb93a386Sopenharmony_ci                const FunctionDeclaration& func = expr.as<FunctionCall>().function();
128cb93a386Sopenharmony_ci                if (func.isBuiltin()) {
129cb93a386Sopenharmony_ci                    if (func.intrinsicKind() == k_dFdy_IntrinsicKind) {
130cb93a386Sopenharmony_ci                        ThreadContext::Inputs().fUseFlipRTUniform = true;
131cb93a386Sopenharmony_ci                    }
132cb93a386Sopenharmony_ci                    if (func.definition()) {
133cb93a386Sopenharmony_ci                        fReferencedIntrinsics->insert(&func);
134cb93a386Sopenharmony_ci                    }
135cb93a386Sopenharmony_ci                    if (!fContext.fConfig->fIsBuiltinCode && fContext.fIntrinsics) {
136cb93a386Sopenharmony_ci                        this->copyIntrinsicIfNeeded(func);
137cb93a386Sopenharmony_ci                    }
138cb93a386Sopenharmony_ci                }
139cb93a386Sopenharmony_ci
140cb93a386Sopenharmony_ci            }
141cb93a386Sopenharmony_ci            return INHERITED::visitExpression(expr);
142cb93a386Sopenharmony_ci        }
143cb93a386Sopenharmony_ci
144cb93a386Sopenharmony_ci        bool visitStatement(Statement& stmt) override {
145cb93a386Sopenharmony_ci            switch (stmt.kind()) {
146cb93a386Sopenharmony_ci                case Statement::Kind::kVarDeclaration: {
147cb93a386Sopenharmony_ci                    // We count the number of slots used, but don't consider the precision of the
148cb93a386Sopenharmony_ci                    // base type. In practice, this reflects what GPUs really do pretty well.
149cb93a386Sopenharmony_ci                    // (i.e., RelaxedPrecision math doesn't mean your variable takes less space.)
150cb93a386Sopenharmony_ci                    // We also don't attempt to reclaim slots at the end of a Block.
151cb93a386Sopenharmony_ci                    size_t prevSlotsUsed = fSlotsUsed;
152cb93a386Sopenharmony_ci                    fSlotsUsed = SkSafeMath::Add(
153cb93a386Sopenharmony_ci                            fSlotsUsed, stmt.as<VarDeclaration>().var().type().slotCount());
154cb93a386Sopenharmony_ci                    // To avoid overzealous error reporting, only trigger the error at the first
155cb93a386Sopenharmony_ci                    // place where the stack limit is exceeded.
156cb93a386Sopenharmony_ci                    if (prevSlotsUsed < kVariableSlotLimit && fSlotsUsed >= kVariableSlotLimit) {
157cb93a386Sopenharmony_ci                        fContext.fErrors->error(stmt.fLine, "variable '" +
158cb93a386Sopenharmony_ci                                                            stmt.as<VarDeclaration>().var().name() +
159cb93a386Sopenharmony_ci                                                            "' exceeds the stack size limit");
160cb93a386Sopenharmony_ci                    }
161cb93a386Sopenharmony_ci                    break;
162cb93a386Sopenharmony_ci                }
163cb93a386Sopenharmony_ci                case Statement::Kind::kReturn: {
164cb93a386Sopenharmony_ci                    // Early returns from a vertex main() function will bypass sk_Position
165cb93a386Sopenharmony_ci                    // normalization, so SkASSERT that we aren't doing that. If this becomes an
166cb93a386Sopenharmony_ci                    // issue, we can add normalization before each return statement.
167cb93a386Sopenharmony_ci                    if (fContext.fConfig->fKind == ProgramKind::kVertex && fFunction.isMain()) {
168cb93a386Sopenharmony_ci                        fContext.fErrors->error(
169cb93a386Sopenharmony_ci                                stmt.fLine,
170cb93a386Sopenharmony_ci                                "early returns from vertex programs are not supported");
171cb93a386Sopenharmony_ci                    }
172cb93a386Sopenharmony_ci
173cb93a386Sopenharmony_ci                    // Verify that the return statement matches the function's return type.
174cb93a386Sopenharmony_ci                    ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
175cb93a386Sopenharmony_ci                    if (returnStmt.expression()) {
176cb93a386Sopenharmony_ci                        if (this->functionReturnsValue()) {
177cb93a386Sopenharmony_ci                            // Coerce return expression to the function's return type.
178cb93a386Sopenharmony_ci                            returnStmt.setExpression(fFunction.returnType().coerceExpression(
179cb93a386Sopenharmony_ci                                    std::move(returnStmt.expression()), fContext));
180cb93a386Sopenharmony_ci                        } else {
181cb93a386Sopenharmony_ci                            // Returning something from a function with a void return type.
182cb93a386Sopenharmony_ci                            returnStmt.setExpression(nullptr);
183cb93a386Sopenharmony_ci                            fContext.fErrors->error(returnStmt.fLine,
184cb93a386Sopenharmony_ci                                                    "may not return a value from a void function");
185cb93a386Sopenharmony_ci                        }
186cb93a386Sopenharmony_ci                    } else {
187cb93a386Sopenharmony_ci                        if (this->functionReturnsValue()) {
188cb93a386Sopenharmony_ci                            // Returning nothing from a function with a non-void return type.
189cb93a386Sopenharmony_ci                            fContext.fErrors->error(returnStmt.fLine,
190cb93a386Sopenharmony_ci                                                    "expected function to return '" +
191cb93a386Sopenharmony_ci                                                    fFunction.returnType().displayName() + "'");
192cb93a386Sopenharmony_ci                        }
193cb93a386Sopenharmony_ci                    }
194cb93a386Sopenharmony_ci                    break;
195cb93a386Sopenharmony_ci                }
196cb93a386Sopenharmony_ci                case Statement::Kind::kDo:
197cb93a386Sopenharmony_ci                case Statement::Kind::kFor: {
198cb93a386Sopenharmony_ci                    ++fBreakableLevel;
199cb93a386Sopenharmony_ci                    ++fContinuableLevel.front();
200cb93a386Sopenharmony_ci                    bool result = INHERITED::visitStatement(stmt);
201cb93a386Sopenharmony_ci                    --fContinuableLevel.front();
202cb93a386Sopenharmony_ci                    --fBreakableLevel;
203cb93a386Sopenharmony_ci                    return result;
204cb93a386Sopenharmony_ci                }
205cb93a386Sopenharmony_ci                case Statement::Kind::kSwitch: {
206cb93a386Sopenharmony_ci                    ++fBreakableLevel;
207cb93a386Sopenharmony_ci                    fContinuableLevel.push_front(0);
208cb93a386Sopenharmony_ci                    bool result = INHERITED::visitStatement(stmt);
209cb93a386Sopenharmony_ci                    fContinuableLevel.pop_front();
210cb93a386Sopenharmony_ci                    --fBreakableLevel;
211cb93a386Sopenharmony_ci                    return result;
212cb93a386Sopenharmony_ci                }
213cb93a386Sopenharmony_ci                case Statement::Kind::kBreak:
214cb93a386Sopenharmony_ci                    if (fBreakableLevel == 0) {
215cb93a386Sopenharmony_ci                        fContext.fErrors->error(stmt.fLine,
216cb93a386Sopenharmony_ci                                                "break statement must be inside a loop or switch");
217cb93a386Sopenharmony_ci                    }
218cb93a386Sopenharmony_ci                    break;
219cb93a386Sopenharmony_ci                case Statement::Kind::kContinue:
220cb93a386Sopenharmony_ci                    if (fContinuableLevel.front() == 0) {
221cb93a386Sopenharmony_ci                        if (std::any_of(fContinuableLevel.begin(),
222cb93a386Sopenharmony_ci                                        fContinuableLevel.end(),
223cb93a386Sopenharmony_ci                                        [](int level) { return level > 0; })) {
224cb93a386Sopenharmony_ci                            fContext.fErrors->error(stmt.fLine,
225cb93a386Sopenharmony_ci                                                   "continue statement cannot be used in a switch");
226cb93a386Sopenharmony_ci                        } else {
227cb93a386Sopenharmony_ci                            fContext.fErrors->error(stmt.fLine,
228cb93a386Sopenharmony_ci                                                    "continue statement must be inside a loop");
229cb93a386Sopenharmony_ci                        }
230cb93a386Sopenharmony_ci                    }
231cb93a386Sopenharmony_ci                    break;
232cb93a386Sopenharmony_ci                default:
233cb93a386Sopenharmony_ci                    break;
234cb93a386Sopenharmony_ci            }
235cb93a386Sopenharmony_ci            return INHERITED::visitStatement(stmt);
236cb93a386Sopenharmony_ci        }
237cb93a386Sopenharmony_ci
238cb93a386Sopenharmony_ci    private:
239cb93a386Sopenharmony_ci        const Context& fContext;
240cb93a386Sopenharmony_ci        const FunctionDeclaration& fFunction;
241cb93a386Sopenharmony_ci        // which intrinsics have we encountered in this function
242cb93a386Sopenharmony_ci        IntrinsicSet* fReferencedIntrinsics;
243cb93a386Sopenharmony_ci        // how deeply nested we are in breakable constructs (for, do, switch).
244cb93a386Sopenharmony_ci        int fBreakableLevel = 0;
245cb93a386Sopenharmony_ci        // number of slots consumed by all variables declared in the function
246cb93a386Sopenharmony_ci        size_t fSlotsUsed = 0;
247cb93a386Sopenharmony_ci        // how deeply nested we are in continuable constructs (for, do).
248cb93a386Sopenharmony_ci        // We keep a stack (via a forward_list) in order to disallow continue inside of switch.
249cb93a386Sopenharmony_ci        std::forward_list<int> fContinuableLevel{0};
250cb93a386Sopenharmony_ci
251cb93a386Sopenharmony_ci        using INHERITED = ProgramWriter;
252cb93a386Sopenharmony_ci    };
253cb93a386Sopenharmony_ci
254cb93a386Sopenharmony_ci    IntrinsicSet referencedIntrinsics;
255cb93a386Sopenharmony_ci    Finalizer(context, function, &referencedIntrinsics).visitStatement(*body);
256cb93a386Sopenharmony_ci    if (function.isMain() && context.fConfig->fKind == ProgramKind::kVertex) {
257cb93a386Sopenharmony_ci        append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
258cb93a386Sopenharmony_ci    }
259cb93a386Sopenharmony_ci
260cb93a386Sopenharmony_ci    if (Analysis::CanExitWithoutReturningValue(function, *body)) {
261cb93a386Sopenharmony_ci        context.fErrors->error(function.fLine, "function '" + function.name() +
262cb93a386Sopenharmony_ci                                               "' can exit without returning a value");
263cb93a386Sopenharmony_ci    }
264cb93a386Sopenharmony_ci
265cb93a386Sopenharmony_ci    return std::make_unique<FunctionDefinition>(line, &function, builtin, std::move(body),
266cb93a386Sopenharmony_ci                                                std::move(referencedIntrinsics));
267cb93a386Sopenharmony_ci}
268cb93a386Sopenharmony_ci
269cb93a386Sopenharmony_ci}  // namespace SkSL
270