1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#ifndef SKIASL_MEMORYLAYOUT
9#define SKIASL_MEMORYLAYOUT
10
11#include <algorithm>
12
13#include "src/sksl/ir/SkSLType.h"
14
15namespace SkSL {
16
17class MemoryLayout {
18public:
19    enum Standard {
20        k140_Standard,
21        k430_Standard,
22        kMetal_Standard
23    };
24
25    MemoryLayout(Standard std)
26    : fStd(std) {}
27
28    static size_t vector_alignment(size_t componentSize, int columns) {
29        return componentSize * (columns + columns % 2);
30    }
31
32    /**
33     * Rounds up to the nearest multiple of 16 if in std140, otherwise returns the parameter
34     * unchanged (std140 requires various things to be rounded up to the nearest multiple of 16,
35     * std430 does not).
36     */
37    size_t roundUpIfNeeded(size_t raw) const {
38        switch (fStd) {
39            case k140_Standard: return (raw + 15) & ~15;
40            case k430_Standard: return raw;
41            case kMetal_Standard: return raw;
42        }
43        SkUNREACHABLE;
44    }
45
46    /**
47     * Returns a type's required alignment when used as a standalone variable.
48     */
49    size_t alignment(const Type& type) const {
50        // See OpenGL Spec 7.6.2.2 Standard Uniform Block Layout
51        switch (type.typeKind()) {
52            case Type::TypeKind::kScalar:
53                return this->size(type);
54            case Type::TypeKind::kVector:
55                return vector_alignment(this->size(type.componentType()), type.columns());
56            case Type::TypeKind::kMatrix:
57                return this->roundUpIfNeeded(vector_alignment(this->size(type.componentType()),
58                                                              type.rows()));
59            case Type::TypeKind::kArray:
60                return this->roundUpIfNeeded(this->alignment(type.componentType()));
61            case Type::TypeKind::kStruct: {
62                size_t result = 0;
63                for (const auto& f : type.fields()) {
64                    size_t alignment = this->alignment(*f.fType);
65                    if (alignment > result) {
66                        result = alignment;
67                    }
68                }
69                return this->roundUpIfNeeded(result);
70            }
71            default:
72                SK_ABORT("cannot determine size of type %s", String(type.name()).c_str());
73        }
74    }
75
76    /**
77     * For matrices and arrays, returns the number of bytes from the start of one entry (row, in
78     * the case of matrices) to the start of the next.
79     */
80    size_t stride(const Type& type) const {
81        switch (type.typeKind()) {
82            case Type::TypeKind::kMatrix: {
83                size_t base = vector_alignment(this->size(type.componentType()), type.rows());
84                return this->roundUpIfNeeded(base);
85            }
86            case Type::TypeKind::kArray: {
87                int stride = this->size(type.componentType());
88                if (stride > 0) {
89                    int align = this->alignment(type.componentType());
90                    stride += align - 1;
91                    stride -= stride % align;
92                    stride = this->roundUpIfNeeded(stride);
93                }
94                return stride;
95            }
96            default:
97                SK_ABORT("type does not have a stride");
98        }
99    }
100
101    /**
102     * Returns the size of a type in bytes.
103     */
104    size_t size(const Type& type) const {
105        switch (type.typeKind()) {
106            case Type::TypeKind::kScalar:
107                if (type.isBoolean()) {
108                    return 1;
109                }
110                if (fStd == kMetal_Standard && !type.highPrecision() && type.isNumber()) {
111                    return 2;
112                }
113                return 4;
114            case Type::TypeKind::kVector:
115                if (fStd == kMetal_Standard && type.columns() == 3) {
116                    return 4 * this->size(type.componentType());
117                }
118                return type.columns() * this->size(type.componentType());
119            case Type::TypeKind::kMatrix: // fall through
120            case Type::TypeKind::kArray:
121                return type.columns() * this->stride(type);
122            case Type::TypeKind::kStruct: {
123                size_t total = 0;
124                for (const auto& f : type.fields()) {
125                    size_t alignment = this->alignment(*f.fType);
126                    if (total % alignment != 0) {
127                        total += alignment - total % alignment;
128                    }
129                    SkASSERT(total % alignment == 0);
130                    total += this->size(*f.fType);
131                }
132                size_t alignment = this->alignment(type);
133                SkASSERT(!type.fields().size() ||
134                       (0 == alignment % this->alignment(*type.fields()[0].fType)));
135                return (total + alignment - 1) & ~(alignment - 1);
136            }
137            default:
138                SK_ABORT("cannot determine size of type %s", String(type.name()).c_str());
139        }
140    }
141
142    /**
143     * Not all types are compatible with memory layout.
144     */
145    static size_t LayoutIsSupported(const Type& type) {
146        switch (type.typeKind()) {
147            case Type::TypeKind::kScalar:
148            case Type::TypeKind::kVector:
149            case Type::TypeKind::kMatrix:
150                return true;
151
152            case Type::TypeKind::kArray:
153                return LayoutIsSupported(type.componentType());
154
155            case Type::TypeKind::kStruct:
156                return std::all_of(
157                        type.fields().begin(), type.fields().end(),
158                        [](const Type::Field& f) { return LayoutIsSupported(*f.fType); });
159
160            default:
161                return false;
162        }
163    }
164
165    const Standard fStd;
166};
167
168}  // namespace SkSL
169
170#endif
171