1/*
2 * Copyright 2021 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
#include "src/sksl/lex/DFA.h"
#include "src/sksl/lex/TransitionTable.h"

#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cmath>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
18
namespace {

// The number of bits to use per entry in our compact transition table. This is customizable:
// - 1-bit: reasonable in theory. Doesn't actually pack many slices.
// - 2-bit: best fit for our data. Packs extremely well.
// - 4-bit: packs all but one slice, but doesn't save as much space overall.
// - 8-bit: way too large (an 8-bit LUT plus an 8-bit data table is as big as a 16-bit table)
// Other values don't divide cleanly into a byte and do not work.
constexpr int kNumBits = 2;
static_assert(kNumBits > 0 && 8 % kNumBits == 0, "kNumBits must evenly divide a byte");

// These values are derived from kNumBits and shouldn't need to change.
// kNumValues: the number of distinct nonzero values one compact entry can reference. Packed code 0
// is reserved for "transition to state 0", leaving 2^kNumBits - 1 codes. The same number is also
// used as the bit mask that extracts one packed code.
constexpr int kNumValues = (1 << kNumBits) - 1;
// kDataPerByte: how many packed codes fit into one byte of compact data.
constexpr int kDataPerByte = 8 / kNumBits;

// How a single state's transition row is stored. The numeric values are emitted verbatim into the
// generated table, and the generated get_transition() tests `type == 0` and `type == 1` directly,
// so they must not be renumbered.
enum IndexType {
    kZero = 0,      // Every transition goes to state 0; no row data is stored.
    kFullEntry,     // Row stored verbatim in the full table.
    kCompactEntry,  // Row stored as packed kNumBits-wide codes in the compact table.
};

// One entry per state: which table holds its row, and the row's index within that table.
struct IndexEntry {
    IndexType type;
    int pos;
};

// A row with at most kNumValues distinct nonzero targets.
//   v:    the distinct nonzero targets, sorted ascending, with unused slots zero-padded at the front.
//   data: one small code (0..kNumValues) per transition; code i > 0 selects v[i - 1].
struct CompactEntry {
    std::array<int, kNumValues> v = {};
    std::vector<int> data;
};

// A row stored verbatim: one target state per transition.
struct FullEntry {
    std::vector<int> data;
};

// The set of distinct target states appearing in a row.
using TransitionSet = std::unordered_set<int>;

// Encodes `data` (one target state per transition) as a CompactEntry, appends it to `entries`
// unless an identical entry already exists, and returns the entry's index in `entries`.
//
// `transitionSet` must hold exactly the distinct nonzero values present in `data`, and there
// must be at most kNumValues of them (asserted).
int add_compact_entry(const TransitionSet& transitionSet,
                      const std::vector<int>& data,
                      std::vector<CompactEntry>* entries) {
    // Create a compact entry with the unique values from the transition set, padded out with zeros
    // and sorted. Because the padding zeros sort to the front, value v[i] always receives code
    // i + 1. That matches the generated lookup table {0, v0, v1, ...}.
    CompactEntry result{};
    assert(transitionSet.size() <= result.v.size());
    std::copy(transitionSet.begin(), transitionSet.end(), result.v.begin());
    std::sort(result.v.begin(), result.v.end());

    // Create a mapping from real values to small values. (0 -> 0, v[0] -> 1, v[1] -> 2, v[2] -> 3)
    std::unordered_map<int, int> translationTable;
    for (size_t index = 0; index < result.v.size(); ++index) {
        translationTable[result.v[index]] = static_cast<int>(1 + index);
    }
    // Padding zeros were just given a nonzero code; a real transition to state 0 must encode as 0.
    translationTable[0] = 0;

    // Convert the real values into small values.
    result.data.reserve(data.size());
    for (int value : data) {
        assert(translationTable.count(value) == 1);
        result.data.push_back(translationTable[value]);
    }

    // Reuse an existing entry that exactly matches this one.
    auto match = std::find_if(entries->begin(), entries->end(), [&](const CompactEntry& entry) {
        return entry.v == result.v && entry.data == result.data;
    });
    if (match != entries->end()) {
        return static_cast<int>(match - entries->begin());
    }

    // Add this as a new entry.
    entries->push_back(std::move(result));
    return static_cast<int>(entries->size() - 1);
}

// Appends `data` (one target state per transition) to `entries` as a FullEntry, unless an
// identical entry already exists, and returns the entry's index in `entries`.
int add_full_entry(const TransitionSet& transitionMap,
                   const std::vector<int>& data,
                   std::vector<FullEntry>* entries) {
    // Not needed to build a full entry. Presumably kept so both add_*_entry helpers share a
    // signature.
    (void)transitionMap;

    // Reuse an existing entry that exactly matches this one.
    auto match = std::find_if(entries->begin(), entries->end(),
                              [&](const FullEntry& entry) { return entry.data == data; });
    if (match != entries->end()) {
        return static_cast<int>(match - entries->begin());
    }

    // Add this as a new entry.
    entries->push_back(FullEntry{data});
    return static_cast<int>(entries->size() - 1);
}

}  // namespace
108
109void WriteTransitionTable(std::ofstream& out, const DFA& dfa, size_t states) {
110    int numTransitions = dfa.fTransitions.size();
111
112    // Assemble our compact and full data tables, and an index into them.
113    std::vector<CompactEntry> compactEntries;
114    std::vector<FullEntry> fullEntries;
115    std::vector<IndexEntry> indices;
116    for (size_t s = 0; s < states; ++s) {
117        // Copy all the transitions for this state into a flat array, and into a histogram (counting
118        // the number of unique state-transition values). Most states only transition to a few
119        // possible new states.
120        TransitionSet transitionSet;
121        std::vector<int> data(numTransitions);
122        for (int t = 0; t < numTransitions; ++t) {
123            if ((size_t) t < dfa.fTransitions.size() && s < dfa.fTransitions[t].size()) {
124                int value = dfa.fTransitions[t][s];
125                assert(value >= 0 && value < (int)states);
126                data[t] = value;
127                transitionSet.insert(value);
128            }
129        }
130
131        transitionSet.erase(0);
132        if (transitionSet.empty()) {
133            // This transition table was completely empty (every value was zero). No data needed;
134            // zero pages are handled as a special index type.
135            indices.push_back(IndexEntry{kZero, 0});
136        } else if (transitionSet.size() <= kNumValues) {
137            // This table only contained a small number of unique nonzero values.
138            // Use a compact representation that squishes each value down to a few bits.
139            int index = add_compact_entry(transitionSet, data, &compactEntries);
140            indices.push_back(IndexEntry{kCompactEntry, index});
141        } else {
142            // This table contained a large number of values. We can't compact it.
143            int index = add_full_entry(transitionSet, data, &fullEntries);
144            indices.push_back(IndexEntry{kFullEntry, index});
145        }
146    }
147
148    // Find the largest value for each compact-entry slot.
149    int maxValue[kNumValues] = {};
150    for (const CompactEntry& entry : compactEntries) {
151        for (int index=0; index < kNumValues; ++index) {
152            maxValue[index] = std::max(maxValue[index], entry.v[index]);
153        }
154    }
155
156    // Emit all the structs our transition table will use.
157    out << "struct IndexEntry {\n"
158        << "    uint16_t type : 2;\n"
159        << "    uint16_t pos : 14;\n"
160        << "};\n"
161        << "struct FullEntry {\n"
162        << "    State data[" << numTransitions << "];\n"
163        << "};\n";
164
165    // Emit the compact-entry structure; minimize the number of bits needed per value.
166    out << "struct CompactEntry {\n";
167    for (int index=0; index < kNumValues; ++index) {
168        if (maxValue[index] > 0) {
169            out << "    State v" << index << " : " << int(std::ceil(std::log2(maxValue[index])))
170                << ";\n";
171        }
172    }
173
174    out << "    uint8_t data[" << std::ceil(float(numTransitions) / float(kDataPerByte)) << "];\n"
175        << "};\n";
176
177    // Emit the full-table data.
178    out << "static constexpr FullEntry kFull[] = {\n";
179    for (const FullEntry& entry : fullEntries) {
180        out << "    {";
181        for (int value : entry.data) {
182            out << value << ", ";
183        }
184        out << "},\n";
185    }
186    out << "};\n";
187
188    // Emit the compact-table data.
189    out << "static constexpr CompactEntry kCompact[] = {\n";
190    for (const CompactEntry& entry : compactEntries) {
191        out << "    {";
192        for (int index=0; index < kNumValues; ++index) {
193            if (maxValue[index] > 0) {
194                out << entry.v[index] << ", ";
195            }
196        }
197        out << "{";
198        unsigned int shiftBits = 0, combinedBits = 0;
199        for (int index = 0; index < numTransitions; index++) {
200            combinedBits |= entry.data[index] << shiftBits;
201            shiftBits += kNumBits;
202            assert(shiftBits <= 8);
203            if (shiftBits == 8) {
204                out << combinedBits << ", ";
205                shiftBits = 0;
206                combinedBits = 0;
207            }
208        }
209        if (shiftBits > 0) {
210            // Flush any partial values.
211            out << combinedBits;
212        }
213        out << "}},\n";
214    }
215    out << "};\n"
216        << "static constexpr IndexEntry kIndices[] = {\n";
217    for (const IndexEntry& entry : indices) {
218        out << "    {" << entry.type << ", " << entry.pos << "},\n";
219    }
220    out << "};\n"
221        << "State get_transition(int transition, int state) {\n"
222        << "    IndexEntry index = kIndices[state];\n"
223        << "    if (index.type == 0) { return 0; }\n"
224        << "    if (index.type == 1) { return kFull[index.pos].data[transition]; }\n"
225        << "    const CompactEntry& entry = kCompact[index.pos];\n"
226        << "    int value = entry.data[transition >> " << std::log2(kDataPerByte) << "];\n"
227        << "    value >>= " << kNumBits << " * (transition & " << kDataPerByte - 1 << ");\n"
228        << "    value &= " << kNumValues << ";\n"
229        << "    State table[] = {0";
230
231    for (int index=0; index < kNumValues; ++index) {
232        if (maxValue[index] > 0) {
233            out << ", entry.v" << index;
234        } else {
235            out << ", 0";
236        }
237    }
238
239    out << "};\n"
240        << "    return table[value];\n"
241        << "}\n";
242}
243