1bf215546Sopenharmony_ci#
2bf215546Sopenharmony_ci# Copyright (C) 2014 Intel Corporation
3bf215546Sopenharmony_ci#
4bf215546Sopenharmony_ci# Permission is hereby granted, free of charge, to any person obtaining a
5bf215546Sopenharmony_ci# copy of this software and associated documentation files (the "Software"),
6bf215546Sopenharmony_ci# to deal in the Software without restriction, including without limitation
7bf215546Sopenharmony_ci# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8bf215546Sopenharmony_ci# and/or sell copies of the Software, and to permit persons to whom the
9bf215546Sopenharmony_ci# Software is furnished to do so, subject to the following conditions:
10bf215546Sopenharmony_ci#
11bf215546Sopenharmony_ci# The above copyright notice and this permission notice (including the next
12bf215546Sopenharmony_ci# paragraph) shall be included in all copies or substantial portions of the
13bf215546Sopenharmony_ci# Software.
14bf215546Sopenharmony_ci#
15bf215546Sopenharmony_ci# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16bf215546Sopenharmony_ci# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17bf215546Sopenharmony_ci# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18bf215546Sopenharmony_ci# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19bf215546Sopenharmony_ci# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20bf215546Sopenharmony_ci# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21bf215546Sopenharmony_ci# IN THE SOFTWARE.
22bf215546Sopenharmony_ci#
23bf215546Sopenharmony_ci# Authors:
24bf215546Sopenharmony_ci#    Jason Ekstrand (jason@jlekstrand.net)
25bf215546Sopenharmony_ci
26bf215546Sopenharmony_ciimport ast
27bf215546Sopenharmony_cifrom collections import defaultdict
28bf215546Sopenharmony_ciimport itertools
29bf215546Sopenharmony_ciimport struct
30bf215546Sopenharmony_ciimport sys
31bf215546Sopenharmony_ciimport mako.template
32bf215546Sopenharmony_ciimport re
33bf215546Sopenharmony_ciimport traceback
34bf215546Sopenharmony_ci
35bf215546Sopenharmony_cifrom nir_opcodes import opcodes, type_sizes
36bf215546Sopenharmony_ci
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
# NOTE(review): presumably the cap on commutative expressions per search
# pattern (cf. the "many-comm-expr" handling in Expression); the code that
# consumes this value is not in this part of the file -- confirm there.
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
#
# Membership in this table is what makes get_c_opcode() emit a
# 'nir_search_op_' name instead of a 'nir_op_' name.  Expression also refuses
# an explicit '@bits' suffix on any opcode listed here and never counts such
# an opcode as commutative.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}
55bf215546Sopenharmony_ci
def get_cond_index(conds, cond):
    """Return the index registered for ``cond`` in ``conds``, adding it if new.

    ``conds`` is a dict mapping each condition to the position at which it
    was first registered.  A falsy ``cond`` (``None`` or an empty string)
    means "no condition": -1 is returned and ``conds`` is left untouched.
    """
    if not cond:
        return -1

    # A condition seen before keeps its original index; a new one is numbered
    # with the current size of the table and stored.
    return conds.setdefault(cond, len(conds))
66bf215546Sopenharmony_ci
def get_c_opcode(op):
   """Return the C enumerant name for the opcode string ``op``.

   Opcodes listed in conv_opcode_types get the 'nir_search_op_' prefix;
   every other opcode gets the 'nir_op_' prefix.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
72bf215546Sopenharmony_ci
# Splits a type string such as "float32" into a base type and a trailing bit
# size.  As far as the pattern itself is concerned both parts are optional.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size embedded in ``type_str``, or 0 when it is unsized.

   ``type_str`` has to start with one of int, uint, bool or float; if it
   does not, the assertion below fails.
   """
   parsed = _type_re.match(type_str)
   assert parsed.group('type')

   bits = parsed.group('bits')
   return int(bits) if bits is not None else 0
83bf215546Sopenharmony_ci
84bf215546Sopenharmony_ci# Represents a set of variables, each with a unique id
class VarSet(object):
   """Hands out consecutive integer ids to variable names.

   Looking a name up with ``varset[name]`` returns its id, allocating the
   next free id the first time the name is seen.  After lock() has been
   called, looking up an unknown name trips an assertion instead.
   """

   def __init__(self):
      self.names = {}                  # variable name -> id
      self.ids = itertools.count()     # source of fresh ids: 0, 1, 2, ...
      self.immutable = False

   def __getitem__(self, name):
      known = self.names
      if name in known:
         return known[name]

      # Only an unlocked set may grow.
      assert not self.immutable, "Unknown replacement variable: " + name
      fresh_id = next(self.ids)
      known[name] = fresh_id
      return fresh_id

   def lock(self):
      """Forbid the allocation of ids for names not seen so far."""
      self.immutable = True
100bf215546Sopenharmony_ci
class SearchExpression(object):
   """A pattern held as an opcode string plus its sources.

   Unlike a bare tuple, an instance also carries the ``ignore_exact`` flag,
   which starts out False.
   """

   def __init__(self, expr):
      # The first element is the opcode, everything after it is a source.
      self.opcode = expr[0]
      self.sources = expr[1:]
      self.ignore_exact = False

   @staticmethod
   def create(val):
      """Wrap a tuple in a SearchExpression; pass an existing one through.

      Anything that is neither a tuple nor a SearchExpression trips the
      assertion.
      """
      if not isinstance(val, tuple):
         assert(isinstance(val, SearchExpression))
         return val
      return SearchExpression(val)

   def __repr__(self):
      # Rendered as the equivalent tuple, with a trailing 'ignore_exact'
      # marker when the flag is set.
      parts = (self.opcode,) + tuple(self.sources)
      if self.ignore_exact:
         parts += ('ignore_exact',)
      return repr(parts)
120bf215546Sopenharmony_ci
class Value(object):
   """Common base of Constant, Variable and Expression.

   Holds the textual form of the input (``in_val``), a name and the kind of
   value (``type_str``), and implements the bit-size union-find plus the
   rendering of the value as a C initializer.  ``_bit_size`` is read here
   but not assigned by Value.__init__; each subclass sets it itself.
   """

   @staticmethod
   def create(val, name_base, varset, algebraic_pass):
      """Build the Value subclass that corresponds to the Python type of val.

      bytes are decoded as UTF-8 first.  Tuples and SearchExpressions become
      Expressions, an existing Expression is returned unchanged, strings
      become Variables and bool/float/int become Constants.  Any other type
      matches none of the cases and None is returned.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, (tuple, SearchExpression)):
         return Expression(val, name_base, varset, algebraic_pass)
      if isinstance(val, Expression):
         return val
      if isinstance(val, str):
         return Variable(val, name_base, varset, algebraic_pass)
      if isinstance(val, (bool, float, int)):
         return Constant(val, name_base)
      return None

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Return the bit size chosen for this value or, when none has been
      chosen yet, the canonical Value that represents its bit-size class.

      Variables are preferred as canonical values: if the equivalence class
      contains any variable, the representative is a variable, because the
      replacement expression needs to know which variable each value is
      equivalent to.  This is the "find" half of the union-find algorithm.
      """
      root = self

      # Walk the chain of links until reaching either a concrete size (not a
      # Value) or a Value that links nowhere.
      while isinstance(root, Value) and root._bit_size is not None:
         root = root._bit_size

      # Remember the result so that the next lookup is a single hop.
      if root is not self:
         self._bit_size = root
      return root

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before this call, or just "other" if it is a concrete bit size.  This
      is the "union" half of the union-find algorithm.
      """
      mine = self.get_bit_size()
      if isinstance(other, int):
         theirs = other
      else:
         theirs = other.get_bit_size()

      # Already the same class or the same size: nothing to merge.
      if mine == theirs:
         return

      mine._bit_size = theirs

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_bit_size(self):
      resolved = self.get_bit_size()
      if isinstance(resolved, int):
         return resolved
      if isinstance(resolved, Variable):
         # A variable's class is encoded as a negative number.
         return -resolved.index - 1

      # If the bit-size class is neither a variable, nor an actual bit-size, then
      # - If it's in the search expression, we don't need to check anything
      # - If it's in the replace expression, either it's ambiguous (in which
      # case we'd reject it), or it equals the bit-size of the search value
      # We represent these cases with a 0 bit-size.
      return 0

   __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

   def render(self, cache):
      """Return the C initializer text for this value and set array_index.

      ``cache`` maps already rendered initializer text to the array index
      (kept as a string) it was given, and holds the next free index under
      the "next_index" key.  A value whose text was rendered before reuses
      that index and contributes only a comment.
      """
      struct_init = self.__template.render(val=self,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if struct_init not in cache:
         # First occurrence: claim the next free slot.
         self.array_index = str(cache["next_index"])
         cache[struct_init] = self.array_index
         cache["next_index"] += 1
         return struct_init

      # Seen before: point at the existing slot and emit only a remark.
      self.array_index = cache[struct_init]
      return "   /* {} -> {} in the cache */\n".format(self.name,
                                                    self.array_index)
235bf215546Sopenharmony_ci
# A literal, optionally followed by "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool, int or float appearing in a pattern."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if not isinstance(val, str):
         # A plain Python value carries no bit size of its own.
         self.value = val
         self._bit_size = None
      else:
         # A string is parsed as "<literal>" or "<literal>@<bits>".
         parsed = _constant_re.match(val)
         self.value = ast.literal_eval(parsed.group('value'))
         bits = parsed.group('bits')
         self._bit_size = int(bits) if bits else None

      # Booleans are always one bit wide; any explicit size must agree.
      if isinstance(self.value, bool):
         assert self._bit_size in (None, 1)
         self._bit_size = 1

   def hex(self):
      """Return the text used for the constant's value in the generated C.

      Booleans become NIR_TRUE/NIR_FALSE, integers their hex() form, and
      floats the hex form of the 64-bit pattern of the double.
      """
      v = self.value
      # bool must be tested before int because bool is a subclass of int.
      if isinstance(v, bool):
         return 'NIR_TRUE' if v else 'NIR_FALSE'
      if isinstance(v, int):
         return hex(v)
      if isinstance(v, float):
         (raw,) = struct.unpack('Q', struct.pack('d', v))
         return hex(raw)
      assert False

   def type(self):
      """Return the nir_type_* name for the value, or None if it is not a
      bool, int or float.
      """
      v = self.value
      if isinstance(v, bool):
         return "nir_type_bool"
      if isinstance(v, int):
         return "nir_type_int"
      if isinstance(v, float):
         return "nir_type_float"
      return None

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      return isinstance(other, type(self)) and self.value == other.value
284bf215546Sopenharmony_ci
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
#
# The swizzle field accepts every component letter that Variable.swizzle()
# can translate: x, y, z, w and a through p.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwa-p]+)?"
                          r"$")

class Variable(Value):
   """A named placeholder in a pattern.

   The variable string has the form

      [#]name[@[type][bits]][(cond)][.swizzle]

   where a leading '#' marks the variable as constant-only, 'type' is one of
   int, uint, bool or float, 'bits' is a bit size, 'cond' is a condition and
   'swizzle' is a run of component letters.
   """

   # Component letter -> component index.  x/y/z/w and a/b/c/d name the same
   # first four components.
   _SWIZZLE_INDEX = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }

   def __init__(self, val, name, varset, algebraic_pass):
      """Parse the variable string ``val``.

      The variable's condition is registered in
      ``algebraic_pass.variable_cond`` and its name is looked up in (or
      added to) ``varset``.  Malformed input trips an assertion.
      """
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      # A bool without an explicit size is 1 bit wide; an explicit size has
      # to be one that type_sizes() lists for bool.
      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* name of the required type, or None when the
      variable does not require one.  int and uint both map to nir_type_int.
      """
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      """Return the swizzle as a C initializer list.

      With an explicit swizzle the list has one index per letter; without
      one it is the 16-entry identity swizzle.
      """
      if self.swiz is None:
         return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

      # self.swiz still carries its leading '.', hence the [1:].
      indices = [str(self._SWIZZLE_INDEX[c]) for c in self.swiz[1:]]
      return '{' + ', '.join(indices) + '}'
357bf215546Sopenharmony_ci
# [~][!]opcode[@bits][(cond)] -- '~' marks the expression inexact, '!' exact.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An opcode applied to up to four source Values."""

   def __init__(self, expr, name_base, varset, algebraic_pass):
      """Parse ``expr`` (a tuple or SearchExpression) into an Expression.

      The expression's own condition is registered in
      ``algebraic_pass.expression_cond`` before its sources are built, and
      each source is built with Value.create() under the name
      "<name_base>_<position>".
      """
      Value.__init__(self, expr, name_base, "expression")

      expr = SearchExpression.create(expr)

      m = _opcode_re.match(expr.opcode)
      assert m and m.group('opcode') is not None

      bits = m.group('bits')
      self.opcode = m.group('opcode')
      self._bit_size = int(bits) if bits else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.ignore_exact = expr.ignore_exact
      self.cond = m.group('cond')

      assert not (self.inexact and self.exact), \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and "many-comm-expr" in self.cond:
         # Drop the surrounding parentheses, split on commas and take the
         # marker out.  At most one real condition may remain.
         remaining = self.cond[1:-1].split(",")
         remaining.remove("many-comm-expr")
         assert(len(remaining) <= 1)

         self.cond = remaining[0] if remaining else None
         self.many_commutative_expressions = True

      # Deduplicate references to the condition functions for the expressions
      # and save the index for the order they were added.
      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

      sources = []
      for i, src in enumerate(expr.sources):
         src_name = "{0}_{1}".format(name_base, i)
         sources.append(Value.create(src, src_name, varset, algebraic_pass))
      self.sources = sources

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources) or self.opcode != other.opcode:
         return False

      # Sources are compared pairwise, in order, stopping at the first pair
      # that differs.
      for mine, theirs in zip(self.sources, other.sources):
         if not mine.equivalent(theirs):
            return False
      return True

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions.

      Sets comm_expr_idx (-1 when this expression is not counted as
      commutative) and comm_exprs (the count for this expression and all
      expressions below it), and returns comm_exprs.
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      counts_as_commutative = (
         self.opcode not in conv_opcode_types and
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and
         len(self.sources) >= 2 and
         not self.sources[0].equivalent(self.sources[1]))

      if counts_as_commutative:
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      # Each nested expression starts numbering after everything counted so
      # far in this subtree.
      for child in self.sources:
         if not isinstance(child, Expression):
            continue
         child.__index_comm_exprs(base_idx + self.comm_exprs)
         self.comm_exprs += child.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      """Render every source, in order, followed by this expression."""
      rendered_sources = "".join(src.render(cache) for src in self.sources)
      return rendered_sources + super(Expression, self).render(cache)
464bf215546Sopenharmony_ci
465bf215546Sopenharmony_ciclass BitSizeValidator(object):
466bf215546Sopenharmony_ci   """A class for validating bit sizes of expressions.
467bf215546Sopenharmony_ci
468bf215546Sopenharmony_ci   NIR supports multiple bit-sizes on expressions in order to handle things
469bf215546Sopenharmony_ci   such as fp64.  The source and destination of every ALU operation is
470bf215546Sopenharmony_ci   assigned a type and that type may or may not specify a bit size.  Sources
471bf215546Sopenharmony_ci   and destinations whose type does not specify a bit size are considered
472bf215546Sopenharmony_ci   "unsized" and automatically take on the bit size of the corresponding
473bf215546Sopenharmony_ci   register or SSA value.  NIR has two simple rules for bit sizes that are
474bf215546Sopenharmony_ci   validated by nir_validator:
475bf215546Sopenharmony_ci
476bf215546Sopenharmony_ci    1) A given SSA def or register has a single bit size that is respected by
477bf215546Sopenharmony_ci       everything that reads from it or writes to it.
478bf215546Sopenharmony_ci
479bf215546Sopenharmony_ci    2) The bit sizes of all unsized inputs/outputs on any given ALU
480bf215546Sopenharmony_ci       instruction must match.  They need not match the sized inputs or
481bf215546Sopenharmony_ci       outputs but they must match each other.
482bf215546Sopenharmony_ci
483bf215546Sopenharmony_ci   In order to keep nir_algebraic relatively simple and easy-to-use,
484bf215546Sopenharmony_ci   nir_search supports a type of bit-size inference based on the two rules
485bf215546Sopenharmony_ci   above.  This is similar to type inference in many common programming
486bf215546Sopenharmony_ci   languages.  If, for instance, you are constructing an add operation and you
487bf215546Sopenharmony_ci   know the second source is 16-bit, then you know that the other source and
488bf215546Sopenharmony_ci   the destination must also be 16-bit.  There are, however, cases where this
489bf215546Sopenharmony_ci   inference can be ambiguous or contradictory.  Consider, for instance, the
490bf215546Sopenharmony_ci   following transformation:
491bf215546Sopenharmony_ci
492bf215546Sopenharmony_ci   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
493bf215546Sopenharmony_ci
494bf215546Sopenharmony_ci   This transformation can potentially cause a problem because usub_borrow is
495bf215546Sopenharmony_ci   well-defined for any bit-size of integer.  However, b2i always generates a
496bf215546Sopenharmony_ci   32-bit result so it could end up replacing a 64-bit expression with one
497bf215546Sopenharmony_ci   that takes two 64-bit values and produces a 32-bit value.  As another
498bf215546Sopenharmony_ci   example, consider this expression:
499bf215546Sopenharmony_ci
500bf215546Sopenharmony_ci   (('bcsel', a, b, 0), ('iand', a, b))
501bf215546Sopenharmony_ci
502bf215546Sopenharmony_ci   In this case, in the search expression a must be 32-bit but b can
503bf215546Sopenharmony_ci   potentially have any bit size.  If we had a 64-bit b value, we would end up
504bf215546Sopenharmony_ci   trying to and a 32-bit value with a 64-bit value which would be invalid
505bf215546Sopenharmony_ci
506bf215546Sopenharmony_ci   This class solves that problem by providing a validation layer that proves
507bf215546Sopenharmony_ci   that a given search-and-replace operation is 100% well-defined before we
508bf215546Sopenharmony_ci   generate any code.  This ensures that bugs are caught at compile time
509bf215546Sopenharmony_ci   rather than at run time.
510bf215546Sopenharmony_ci
511bf215546Sopenharmony_ci   Each value maintains a "bit-size class", which is either an actual bit size
512bf215546Sopenharmony_ci   or an equivalence class with other values that must have the same bit size.
513bf215546Sopenharmony_ci   The validator works by combining bit-size classes with each other according
514bf215546Sopenharmony_ci   to the NIR rules outlined above, checking that there are no inconsistencies.
515bf215546Sopenharmony_ci   When doing this for the replacement expression, we make sure to never change
516bf215546Sopenharmony_ci   the equivalence class of any of the search values. We could make the example
517bf215546Sopenharmony_ci   transforms above work by doing some extra run-time checking of the search
518bf215546Sopenharmony_ci   expression, but we make the user specify those constraints themselves, to
519bf215546Sopenharmony_ci   avoid any surprises. Since the replacement bitsizes can only be connected to
520bf215546Sopenharmony_ci   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
522bf215546Sopenharmony_ci   replacement expression must produce the same bit size as the search
523bf215546Sopenharmony_ci   expression), we prevent merging a variable with anything when processing the
524bf215546Sopenharmony_ci   replacement expression, or specializing the search bitsize
525bf215546Sopenharmony_ci   with anything. The former prevents
526bf215546Sopenharmony_ci
527bf215546Sopenharmony_ci   (('bcsel', a, b, 0), ('iand', a, b))
528bf215546Sopenharmony_ci
529bf215546Sopenharmony_ci   from being allowed, since we'd have to merge the bitsizes for a and b due to
530bf215546Sopenharmony_ci   the 'iand', while the latter prevents
531bf215546Sopenharmony_ci
532bf215546Sopenharmony_ci   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
533bf215546Sopenharmony_ci
534bf215546Sopenharmony_ci   from being allowed, since the search expression has the bit size of a and b,
535bf215546Sopenharmony_ci   which can't be specialized to 32 which is the bitsize of the replace
536bf215546Sopenharmony_ci   expression. It also prevents something like:
537bf215546Sopenharmony_ci
538bf215546Sopenharmony_ci   (('b2i', ('i2b', a)), ('ineq', a, 0))
539bf215546Sopenharmony_ci
540bf215546Sopenharmony_ci   since the bitsize of 'b2i', which can be anything, can't be specialized to
541bf215546Sopenharmony_ci   the bitsize of a.
542bf215546Sopenharmony_ci
543bf215546Sopenharmony_ci   After doing all this, we check that every subexpression of the replacement
544bf215546Sopenharmony_ci   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
547bf215546Sopenharmony_ci   needed in nir_search_value so that we know what to do when building the
548bf215546Sopenharmony_ci   replacement expression.
549bf215546Sopenharmony_ci   """
550bf215546Sopenharmony_ci
551bf215546Sopenharmony_ci   def __init__(self, varset):
552bf215546Sopenharmony_ci      self._var_classes = [None] * len(varset.names)
553bf215546Sopenharmony_ci
554bf215546Sopenharmony_ci   def compare_bitsizes(self, a, b):
555bf215546Sopenharmony_ci      """Determines which bitsize class is a specialization of the other, or
556bf215546Sopenharmony_ci      whether neither is. When we merge two different bitsizes, the
557bf215546Sopenharmony_ci      less-specialized bitsize always points to the more-specialized one, so
558bf215546Sopenharmony_ci      that calling get_bit_size() always gets you the most specialized bitsize.
559bf215546Sopenharmony_ci      The specialization partial order is given by:
560bf215546Sopenharmony_ci      - Physical bitsizes are always the most specialized, and a different
561bf215546Sopenharmony_ci        bitsize can never specialize another.
562bf215546Sopenharmony_ci      - In the search expression, variables can always be specialized to each
563bf215546Sopenharmony_ci        other and to physical bitsizes. In the replace expression, we disallow
564bf215546Sopenharmony_ci        this to avoid adding extra constraints to the search expression that
565bf215546Sopenharmony_ci        the user didn't specify.
566bf215546Sopenharmony_ci      - Expressions and constants without a bitsize can always be specialized to
567bf215546Sopenharmony_ci        each other and variables, but not the other way around.
568bf215546Sopenharmony_ci
569bf215546Sopenharmony_ci        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
570bf215546Sopenharmony_ci        and None if they are not comparable (neither a <= b nor b <= a).
571bf215546Sopenharmony_ci      """
572bf215546Sopenharmony_ci      if isinstance(a, int):
573bf215546Sopenharmony_ci         if isinstance(b, int):
574bf215546Sopenharmony_ci            return 0 if a == b else None
575bf215546Sopenharmony_ci         elif isinstance(b, Variable):
576bf215546Sopenharmony_ci            return -1 if self.is_search else None
577bf215546Sopenharmony_ci         else:
578bf215546Sopenharmony_ci            return -1
579bf215546Sopenharmony_ci      elif isinstance(a, Variable):
580bf215546Sopenharmony_ci         if isinstance(b, int):
581bf215546Sopenharmony_ci            return 1 if self.is_search else None
582bf215546Sopenharmony_ci         elif isinstance(b, Variable):
583bf215546Sopenharmony_ci            return 0 if self.is_search or a.index == b.index else None
584bf215546Sopenharmony_ci         else:
585bf215546Sopenharmony_ci            return -1
586bf215546Sopenharmony_ci      else:
587bf215546Sopenharmony_ci         if isinstance(b, int):
588bf215546Sopenharmony_ci            return 1
589bf215546Sopenharmony_ci         elif isinstance(b, Variable):
590bf215546Sopenharmony_ci            return 1
591bf215546Sopenharmony_ci         else:
592bf215546Sopenharmony_ci            return 0
593bf215546Sopenharmony_ci
594bf215546Sopenharmony_ci   def unify_bit_size(self, a, b, error_msg):
595bf215546Sopenharmony_ci      """Record that a must have the same bit-size as b. If both
596bf215546Sopenharmony_ci      have been assigned conflicting physical bit-sizes, call "error_msg" with
597bf215546Sopenharmony_ci      the bit-sizes of self and other to get a message and raise an error.
598bf215546Sopenharmony_ci      In the replace expression, disallow merging variables with other
599bf215546Sopenharmony_ci      variables and physical bit-sizes as well.
600bf215546Sopenharmony_ci      """
601bf215546Sopenharmony_ci      a_bit_size = a.get_bit_size()
602bf215546Sopenharmony_ci      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
603bf215546Sopenharmony_ci
604bf215546Sopenharmony_ci      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
605bf215546Sopenharmony_ci
606bf215546Sopenharmony_ci      assert cmp_result is not None, \
607bf215546Sopenharmony_ci         error_msg(a_bit_size, b_bit_size)
608bf215546Sopenharmony_ci
609bf215546Sopenharmony_ci      if cmp_result < 0:
610bf215546Sopenharmony_ci         b_bit_size.set_bit_size(a)
611bf215546Sopenharmony_ci      elif not isinstance(a_bit_size, int):
612bf215546Sopenharmony_ci         a_bit_size.set_bit_size(b)
613bf215546Sopenharmony_ci
614bf215546Sopenharmony_ci   def merge_variables(self, val):
615bf215546Sopenharmony_ci      """Perform the first part of type inference by merging all the different
616bf215546Sopenharmony_ci      uses of the same variable. We always do this as if we're in the search
617bf215546Sopenharmony_ci      expression, even if we're actually not, since otherwise we'd get errors
618bf215546Sopenharmony_ci      if the search expression specified some constraint but the replace
619bf215546Sopenharmony_ci      expression didn't, because we'd be merging a variable and a constant.
620bf215546Sopenharmony_ci      """
621bf215546Sopenharmony_ci      if isinstance(val, Variable):
622bf215546Sopenharmony_ci         if self._var_classes[val.index] is None:
623bf215546Sopenharmony_ci            self._var_classes[val.index] = val
624bf215546Sopenharmony_ci         else:
625bf215546Sopenharmony_ci            other = self._var_classes[val.index]
626bf215546Sopenharmony_ci            self.unify_bit_size(other, val,
627bf215546Sopenharmony_ci                  lambda other_bit_size, bit_size:
628bf215546Sopenharmony_ci                     'Variable {} has conflicting bit size requirements: ' \
629bf215546Sopenharmony_ci                     'it must have bit size {} and {}'.format(
630bf215546Sopenharmony_ci                        val.var_name, other_bit_size, bit_size))
631bf215546Sopenharmony_ci      elif isinstance(val, Expression):
632bf215546Sopenharmony_ci         for src in val.sources:
633bf215546Sopenharmony_ci            self.merge_variables(src)
634bf215546Sopenharmony_ci
635bf215546Sopenharmony_ci   def validate_value(self, val):
636bf215546Sopenharmony_ci      """Validate the an expression by performing classic Hindley-Milner
637bf215546Sopenharmony_ci      type inference on bitsizes. This will detect if there are any conflicting
638bf215546Sopenharmony_ci      requirements, and unify variables so that we know which variables must
639bf215546Sopenharmony_ci      have the same bitsize. If we're operating on the replace expression, we
640bf215546Sopenharmony_ci      will refuse to merge different variables together or merge a variable
641bf215546Sopenharmony_ci      with a constant, in order to prevent surprises due to rules unexpectedly
642bf215546Sopenharmony_ci      not matching at runtime.
643bf215546Sopenharmony_ci      """
644bf215546Sopenharmony_ci      if not isinstance(val, Expression):
645bf215546Sopenharmony_ci         return
646bf215546Sopenharmony_ci
647bf215546Sopenharmony_ci      # Generic conversion ops are special in that they have a single unsized
648bf215546Sopenharmony_ci      # source and an unsized destination and the two don't have to match.
649bf215546Sopenharmony_ci      # This means there's no validation or unioning to do here besides the
650bf215546Sopenharmony_ci      # len(val.sources) check.
651bf215546Sopenharmony_ci      if val.opcode in conv_opcode_types:
652bf215546Sopenharmony_ci         assert len(val.sources) == 1, \
653bf215546Sopenharmony_ci            "Expression {} has {} sources, expected 1".format(
654bf215546Sopenharmony_ci               val, len(val.sources))
655bf215546Sopenharmony_ci         self.validate_value(val.sources[0])
656bf215546Sopenharmony_ci         return
657bf215546Sopenharmony_ci
658bf215546Sopenharmony_ci      nir_op = opcodes[val.opcode]
659bf215546Sopenharmony_ci      assert len(val.sources) == nir_op.num_inputs, \
660bf215546Sopenharmony_ci         "Expression {} has {} sources, expected {}".format(
661bf215546Sopenharmony_ci            val, len(val.sources), nir_op.num_inputs)
662bf215546Sopenharmony_ci
663bf215546Sopenharmony_ci      for src in val.sources:
664bf215546Sopenharmony_ci         self.validate_value(src)
665bf215546Sopenharmony_ci
666bf215546Sopenharmony_ci      dst_type_bits = type_bits(nir_op.output_type)
667bf215546Sopenharmony_ci
668bf215546Sopenharmony_ci      # First, unify all the sources. That way, an error coming up because two
669bf215546Sopenharmony_ci      # sources have an incompatible bit-size won't produce an error message
670bf215546Sopenharmony_ci      # involving the destination.
671bf215546Sopenharmony_ci      first_unsized_src = None
672bf215546Sopenharmony_ci      for src_type, src in zip(nir_op.input_types, val.sources):
673bf215546Sopenharmony_ci         src_type_bits = type_bits(src_type)
674bf215546Sopenharmony_ci         if src_type_bits == 0:
675bf215546Sopenharmony_ci            if first_unsized_src is None:
676bf215546Sopenharmony_ci               first_unsized_src = src
677bf215546Sopenharmony_ci               continue
678bf215546Sopenharmony_ci
679bf215546Sopenharmony_ci            if self.is_search:
680bf215546Sopenharmony_ci               self.unify_bit_size(first_unsized_src, src,
681bf215546Sopenharmony_ci                  lambda first_unsized_src_bit_size, src_bit_size:
682bf215546Sopenharmony_ci                     'Source {} of {} must have bit size {}, while source {} ' \
683bf215546Sopenharmony_ci                     'must have incompatible bit size {}'.format(
684bf215546Sopenharmony_ci                        first_unsized_src, val, first_unsized_src_bit_size,
685bf215546Sopenharmony_ci                        src, src_bit_size))
686bf215546Sopenharmony_ci            else:
687bf215546Sopenharmony_ci               self.unify_bit_size(first_unsized_src, src,
688bf215546Sopenharmony_ci                  lambda first_unsized_src_bit_size, src_bit_size:
689bf215546Sopenharmony_ci                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
690bf215546Sopenharmony_ci                     'of {} may not have the same bit size when building the ' \
691bf215546Sopenharmony_ci                     'replacement expression.'.format(
692bf215546Sopenharmony_ci                        first_unsized_src, first_unsized_src_bit_size, src,
693bf215546Sopenharmony_ci                        src_bit_size, val))
694bf215546Sopenharmony_ci         else:
695bf215546Sopenharmony_ci            if self.is_search:
696bf215546Sopenharmony_ci               self.unify_bit_size(src, src_type_bits,
697bf215546Sopenharmony_ci                  lambda src_bit_size, unused:
698bf215546Sopenharmony_ci                     '{} must have {} bits, but as a source of nir_op_{} '\
699bf215546Sopenharmony_ci                     'it must have {} bits'.format(
700bf215546Sopenharmony_ci                        src, src_bit_size, nir_op.name, src_type_bits))
701bf215546Sopenharmony_ci            else:
702bf215546Sopenharmony_ci               self.unify_bit_size(src, src_type_bits,
703bf215546Sopenharmony_ci                  lambda src_bit_size, unused:
704bf215546Sopenharmony_ci                     '{} has the bit size of {}, but as a source of ' \
705bf215546Sopenharmony_ci                     'nir_op_{} it must have {} bits, which may not be the ' \
706bf215546Sopenharmony_ci                     'same'.format(
707bf215546Sopenharmony_ci                        src, src_bit_size, nir_op.name, src_type_bits))
708bf215546Sopenharmony_ci
709bf215546Sopenharmony_ci      if dst_type_bits == 0:
710bf215546Sopenharmony_ci         if first_unsized_src is not None:
711bf215546Sopenharmony_ci            if self.is_search:
712bf215546Sopenharmony_ci               self.unify_bit_size(val, first_unsized_src,
713bf215546Sopenharmony_ci                  lambda val_bit_size, src_bit_size:
714bf215546Sopenharmony_ci                     '{} must have the bit size of {}, while its source {} ' \
715bf215546Sopenharmony_ci                     'must have incompatible bit size {}'.format(
716bf215546Sopenharmony_ci                        val, val_bit_size, first_unsized_src, src_bit_size))
717bf215546Sopenharmony_ci            else:
718bf215546Sopenharmony_ci               self.unify_bit_size(val, first_unsized_src,
719bf215546Sopenharmony_ci                  lambda val_bit_size, src_bit_size:
720bf215546Sopenharmony_ci                     '{} must have {} bits, but its source {} ' \
721bf215546Sopenharmony_ci                     '(bit size of {}) may not have that bit size ' \
722bf215546Sopenharmony_ci                     'when building the replacement.'.format(
723bf215546Sopenharmony_ci                        val, val_bit_size, first_unsized_src, src_bit_size))
724bf215546Sopenharmony_ci      else:
725bf215546Sopenharmony_ci         self.unify_bit_size(val, dst_type_bits,
726bf215546Sopenharmony_ci            lambda dst_bit_size, unused:
727bf215546Sopenharmony_ci               '{} must have {} bits, but as a destination of nir_op_{} ' \
728bf215546Sopenharmony_ci               'it must have {} bits'.format(
729bf215546Sopenharmony_ci                  val, dst_bit_size, nir_op.name, dst_type_bits))
730bf215546Sopenharmony_ci
731bf215546Sopenharmony_ci   def validate_replace(self, val, search):
732bf215546Sopenharmony_ci      bit_size = val.get_bit_size()
733bf215546Sopenharmony_ci      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
734bf215546Sopenharmony_ci            bit_size == search.get_bit_size(), \
735bf215546Sopenharmony_ci            'Ambiguous bit size for replacement value {}: ' \
736bf215546Sopenharmony_ci            'it cannot be deduced from a variable, a fixed bit size ' \
737bf215546Sopenharmony_ci            'somewhere, or the search expression.'.format(val)
738bf215546Sopenharmony_ci
739bf215546Sopenharmony_ci      if isinstance(val, Expression):
740bf215546Sopenharmony_ci         for src in val.sources:
741bf215546Sopenharmony_ci            self.validate_replace(src, search)
742bf215546Sopenharmony_ci
743bf215546Sopenharmony_ci   def validate(self, search, replace):
744bf215546Sopenharmony_ci      self.is_search = True
745bf215546Sopenharmony_ci      self.merge_variables(search)
746bf215546Sopenharmony_ci      self.merge_variables(replace)
747bf215546Sopenharmony_ci      self.validate_value(search)
748bf215546Sopenharmony_ci
749bf215546Sopenharmony_ci      self.is_search = False
750bf215546Sopenharmony_ci      self.validate_value(replace)
751bf215546Sopenharmony_ci
752bf215546Sopenharmony_ci      # Check that search is always more specialized than replace. Note that
753bf215546Sopenharmony_ci      # we're doing this in replace mode, disallowing merging variables.
754bf215546Sopenharmony_ci      search_bit_size = search.get_bit_size()
755bf215546Sopenharmony_ci      replace_bit_size = replace.get_bit_size()
756bf215546Sopenharmony_ci      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
757bf215546Sopenharmony_ci
758bf215546Sopenharmony_ci      assert cmp_result is not None and cmp_result <= 0, \
759bf215546Sopenharmony_ci         'The search expression bit size {} and replace expression ' \
760bf215546Sopenharmony_ci         'bit size {} may not be the same'.format(
761bf215546Sopenharmony_ci               search_bit_size, replace_bit_size)
762bf215546Sopenharmony_ci
763bf215546Sopenharmony_ci      replace.set_bit_size(search)
764bf215546Sopenharmony_ci
765bf215546Sopenharmony_ci      self.validate_replace(replace, search)
766bf215546Sopenharmony_ci
# Shared counter handing out ids; SearchAndReplace.__init__ takes the next
# value for each transform it constructs and uses it to name the search and
# replace expressions.
_optimization_ids = itertools.count()

# Every distinct transform condition seen so far. SearchAndReplace.__init__
# appends conditions that are not yet present and records a condition's
# position in this list as its condition_index. 'true' is the condition used
# when a transform does not specify one, and it sits at index 0.
condition_list = ['true']
770bf215546Sopenharmony_ci
class SearchAndReplace(object):
   """A single search-and-replace transform, validated on construction.

   Attributes set by the constructor:
     id: value drawn from the module-level _optimization_ids counter.
     condition: the transform's condition, or 'true' when it has none.
     condition_index: position of condition in the module-level
        condition_list.
     search: the search Expression.
     replace: the replacement Value.
   """
   def __init__(self, transform, algebraic_pass):
      """Build the transform from a (search, replace[, condition]) sequence.

      search is used as-is when it is already an Expression and replace when
      it is already a Value; otherwise they are constructed from the given
      descriptions, sharing one VarSet that is locked once the search
      expression has been built. The finished pair is handed to
      BitSizeValidator, which raises if the bit sizes are not well-defined.
      """
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Register the condition in the shared list on first sight.
      try:
         self.condition_index = condition_list.index(self.condition)
      except ValueError:
         self.condition_index = len(condition_list)
         condition_list.append(self.condition)

      varset = VarSet()
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
      self.search = search

      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
      self.replace = replace

      BitSizeValidator(varset).validate(self.search, self.replace)
800bf215546Sopenharmony_ci
801bf215546Sopenharmony_ciclass TreeAutomaton(object):
802bf215546Sopenharmony_ci   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
804bf215546Sopenharmony_ci   classical NFA's and DFA's, where the transition function determines the
805bf215546Sopenharmony_ci   state of the parent node based on the state of its children. We construct a
806bf215546Sopenharmony_ci   deterministic automaton to match patterns, using a similar algorithm to the
807bf215546Sopenharmony_ci   classical NFA to DFA construction. At the moment, it only matches opcodes
808bf215546Sopenharmony_ci   and constants (without checking the actual value), leaving more detailed
809bf215546Sopenharmony_ci   checking to the search function which actually checks the leaves. The
810bf215546Sopenharmony_ci   automaton acts as a quick filter for the search function, requiring only n
811bf215546Sopenharmony_ci   + 1 table lookups for each n-source operation. The implementation is based
812bf215546Sopenharmony_ci   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
813bf215546Sopenharmony_ci   In the language of that reference, this is a frontier-to-root deterministic
814bf215546Sopenharmony_ci   automaton using only symbol filtering. The filtering is crucial to reduce
815bf215546Sopenharmony_ci   both the time taken to generate the tables and the size of the tables.
816bf215546Sopenharmony_ci   """
817bf215546Sopenharmony_ci   def __init__(self, transforms):
818bf215546Sopenharmony_ci      self.patterns = [t.search for t in transforms]
819bf215546Sopenharmony_ci      self._compute_items()
820bf215546Sopenharmony_ci      self._build_table()
821bf215546Sopenharmony_ci      #print('num items: {}'.format(len(set(self.items.values()))))
822bf215546Sopenharmony_ci      #print('num states: {}'.format(len(self.states)))
823bf215546Sopenharmony_ci      #for state, patterns in zip(self.states, self.patterns):
824bf215546Sopenharmony_ci      #   print('{}: num patterns: {}'.format(state, len(patterns)))
825bf215546Sopenharmony_ci
826bf215546Sopenharmony_ci   class IndexMap(object):
827bf215546Sopenharmony_ci      """An indexed list of objects, where one can either lookup an object by
828bf215546Sopenharmony_ci      index or find the index associated to an object quickly using a hash
829bf215546Sopenharmony_ci      table. Compared to a list, it has a constant time index(). Compared to a
830bf215546Sopenharmony_ci      set, it provides a stable iteration order.
831bf215546Sopenharmony_ci      """
832bf215546Sopenharmony_ci      def __init__(self, iterable=()):
833bf215546Sopenharmony_ci         self.objects = []
834bf215546Sopenharmony_ci         self.map = {}
835bf215546Sopenharmony_ci         for obj in iterable:
836bf215546Sopenharmony_ci            self.add(obj)
837bf215546Sopenharmony_ci
838bf215546Sopenharmony_ci      def __getitem__(self, i):
839bf215546Sopenharmony_ci         return self.objects[i]
840bf215546Sopenharmony_ci
841bf215546Sopenharmony_ci      def __contains__(self, obj):
842bf215546Sopenharmony_ci         return obj in self.map
843bf215546Sopenharmony_ci
844bf215546Sopenharmony_ci      def __len__(self):
845bf215546Sopenharmony_ci         return len(self.objects)
846bf215546Sopenharmony_ci
847bf215546Sopenharmony_ci      def __iter__(self):
848bf215546Sopenharmony_ci         return iter(self.objects)
849bf215546Sopenharmony_ci
850bf215546Sopenharmony_ci      def clear(self):
851bf215546Sopenharmony_ci         self.objects = []
852bf215546Sopenharmony_ci         self.map.clear()
853bf215546Sopenharmony_ci
854bf215546Sopenharmony_ci      def index(self, obj):
855bf215546Sopenharmony_ci         return self.map[obj]
856bf215546Sopenharmony_ci
857bf215546Sopenharmony_ci      def add(self, obj):
858bf215546Sopenharmony_ci         if obj in self.map:
859bf215546Sopenharmony_ci            return self.map[obj]
860bf215546Sopenharmony_ci         else:
861bf215546Sopenharmony_ci            index = len(self.objects)
862bf215546Sopenharmony_ci            self.objects.append(obj)
863bf215546Sopenharmony_ci            self.map[obj] = index
864bf215546Sopenharmony_ci            return index
865bf215546Sopenharmony_ci
866bf215546Sopenharmony_ci      def __repr__(self):
867bf215546Sopenharmony_ci         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
868bf215546Sopenharmony_ci
869bf215546Sopenharmony_ci   class Item(object):
870bf215546Sopenharmony_ci      """This represents an "item" in the language of "Tree Automatons." This
871bf215546Sopenharmony_ci      is just a subtree of some pattern, which represents a potential partial
872bf215546Sopenharmony_ci      match at runtime. We deduplicate them, so that identical subtrees of
873bf215546Sopenharmony_ci      different patterns share the same object, and store some extra
874bf215546Sopenharmony_ci      information needed for the main algorithm as well.
875bf215546Sopenharmony_ci      """
876bf215546Sopenharmony_ci      def __init__(self, opcode, children):
877bf215546Sopenharmony_ci         self.opcode = opcode
878bf215546Sopenharmony_ci         self.children = children
879bf215546Sopenharmony_ci         # These are the indices of patterns for which this item is the root node.
880bf215546Sopenharmony_ci         self.patterns = []
881bf215546Sopenharmony_ci         # This the set of opcodes for parents of this item. Used to speed up
882bf215546Sopenharmony_ci         # filtering.
883bf215546Sopenharmony_ci         self.parent_ops = set()
884bf215546Sopenharmony_ci
885bf215546Sopenharmony_ci      def __str__(self):
886bf215546Sopenharmony_ci         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
887bf215546Sopenharmony_ci
888bf215546Sopenharmony_ci      def __repr__(self):
889bf215546Sopenharmony_ci         return str(self)
890bf215546Sopenharmony_ci
891bf215546Sopenharmony_ci   def _compute_items(self):
892bf215546Sopenharmony_ci      """Build a set of all possible items, deduplicating them."""
893bf215546Sopenharmony_ci      # This is a map from (opcode, sources) to item.
894bf215546Sopenharmony_ci      self.items = {}
895bf215546Sopenharmony_ci
896bf215546Sopenharmony_ci      # The set of all opcodes used by the patterns. Used later to avoid
897bf215546Sopenharmony_ci      # building and emitting all the tables for opcodes that aren't used.
898bf215546Sopenharmony_ci      self.opcodes = self.IndexMap()
899bf215546Sopenharmony_ci
900bf215546Sopenharmony_ci      def get_item(opcode, children, pattern=None):
901bf215546Sopenharmony_ci         commutative = len(children) >= 2 \
902bf215546Sopenharmony_ci               and "2src_commutative" in opcodes[opcode].algebraic_properties
903bf215546Sopenharmony_ci         item = self.items.setdefault((opcode, children),
904bf215546Sopenharmony_ci                                      self.Item(opcode, children))
905bf215546Sopenharmony_ci         if commutative:
906bf215546Sopenharmony_ci            self.items[opcode, (children[1], children[0]) + children[2:]] = item
907bf215546Sopenharmony_ci         if pattern is not None:
908bf215546Sopenharmony_ci            item.patterns.append(pattern)
909bf215546Sopenharmony_ci         return item
910bf215546Sopenharmony_ci
911bf215546Sopenharmony_ci      self.wildcard = get_item("__wildcard", ())
912bf215546Sopenharmony_ci      self.const = get_item("__const", ())
913bf215546Sopenharmony_ci
914bf215546Sopenharmony_ci      def process_subpattern(src, pattern=None):
915bf215546Sopenharmony_ci         if isinstance(src, Constant):
916bf215546Sopenharmony_ci            # Note: we throw away the actual constant value!
917bf215546Sopenharmony_ci            return self.const
918bf215546Sopenharmony_ci         elif isinstance(src, Variable):
919bf215546Sopenharmony_ci            if src.is_constant:
920bf215546Sopenharmony_ci               return self.const
921bf215546Sopenharmony_ci            else:
922bf215546Sopenharmony_ci               # Note: we throw away which variable it is here! This special
923bf215546Sopenharmony_ci               # item is equivalent to nu in "Tree Automatons."
924bf215546Sopenharmony_ci               return self.wildcard
925bf215546Sopenharmony_ci         else:
926bf215546Sopenharmony_ci            assert isinstance(src, Expression)
927bf215546Sopenharmony_ci            opcode = src.opcode
928bf215546Sopenharmony_ci            stripped = opcode.rstrip('0123456789')
929bf215546Sopenharmony_ci            if stripped in conv_opcode_types:
930bf215546Sopenharmony_ci               # Matches that use conversion opcodes with a specific type,
931bf215546Sopenharmony_ci               # like f2b1, are tricky.  Either we construct the automaton to
932bf215546Sopenharmony_ci               # match specific NIR opcodes like nir_op_f2b1, in which case we
933bf215546Sopenharmony_ci               # need to create separate items for each possible NIR opcode
934bf215546Sopenharmony_ci               # for patterns that have a generic opcode like f2b, or we
935bf215546Sopenharmony_ci               # construct it to match the search opcode, in which case we
936bf215546Sopenharmony_ci               # need to map f2b1 to f2b when constructing the automaton. Here
937bf215546Sopenharmony_ci               # we do the latter.
938bf215546Sopenharmony_ci               opcode = stripped
939bf215546Sopenharmony_ci            self.opcodes.add(opcode)
940bf215546Sopenharmony_ci            children = tuple(process_subpattern(c) for c in src.sources)
941bf215546Sopenharmony_ci            item = get_item(opcode, children, pattern)
942bf215546Sopenharmony_ci            for i, child in enumerate(children):
943bf215546Sopenharmony_ci               child.parent_ops.add(opcode)
944bf215546Sopenharmony_ci            return item
945bf215546Sopenharmony_ci
946bf215546Sopenharmony_ci      for i, pattern in enumerate(self.patterns):
947bf215546Sopenharmony_ci         process_subpattern(pattern, i)
948bf215546Sopenharmony_ci
949bf215546Sopenharmony_ci   def _build_table(self):
950bf215546Sopenharmony_ci      """This is the core algorithm which builds up the transition table. It
951bf215546Sopenharmony_ci      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
952bf215546Sopenharmony_ci      Comp_a and Filt_{a,i} using integers to identify match sets." It
953bf215546Sopenharmony_ci      simultaneously builds up a list of all possible "match sets" or
954bf215546Sopenharmony_ci      "states", where each match set represents the set of Item's that match a
955bf215546Sopenharmony_ci      given instruction, and builds up the transition table between states.
956bf215546Sopenharmony_ci      """
957bf215546Sopenharmony_ci      # Map from opcode + filtered state indices to transitioned state.
958bf215546Sopenharmony_ci      self.table = defaultdict(dict)
959bf215546Sopenharmony_ci      # Bijection from state to index. q in the original algorithm is
960bf215546Sopenharmony_ci      # len(self.states)
961bf215546Sopenharmony_ci      self.states = self.IndexMap()
962bf215546Sopenharmony_ci      # Lists of pattern matches separated by None
963bf215546Sopenharmony_ci      self.state_patterns = [None]
964bf215546Sopenharmony_ci      # Offset in the ->transforms table for each state index
965bf215546Sopenharmony_ci      self.state_pattern_offsets = []
966bf215546Sopenharmony_ci      # Map from state index to filtered state index for each opcode.
967bf215546Sopenharmony_ci      self.filter = defaultdict(list)
968bf215546Sopenharmony_ci      # Bijections from filtered state to filtered state index for each
969bf215546Sopenharmony_ci      # opcode, called the "representor sets" in the original algorithm.
970bf215546Sopenharmony_ci      # q_{a,j} in the original algorithm is len(self.rep[op]).
971bf215546Sopenharmony_ci      self.rep = defaultdict(self.IndexMap)
972bf215546Sopenharmony_ci
973bf215546Sopenharmony_ci      # Everything in self.states with a index at least worklist_index is part
974bf215546Sopenharmony_ci      # of the worklist of newly created states. There is also a worklist of
975bf215546Sopenharmony_ci      # newly fitered states for each opcode, for which worklist_indices
976bf215546Sopenharmony_ci      # serves a similar purpose. worklist_index corresponds to p in the
977bf215546Sopenharmony_ci      # original algorithm, while worklist_indices is p_{a,j} (although since
978bf215546Sopenharmony_ci      # we only filter by opcode/symbol, it's really just p_a).
979bf215546Sopenharmony_ci      self.worklist_index = 0
980bf215546Sopenharmony_ci      worklist_indices = defaultdict(lambda: 0)
981bf215546Sopenharmony_ci
982bf215546Sopenharmony_ci      # This is the set of opcodes for which the filtered worklist is non-empty.
983bf215546Sopenharmony_ci      # It's used to avoid scanning opcodes for which there is nothing to
984bf215546Sopenharmony_ci      # process when building the transition table. It corresponds to new_a in
985bf215546Sopenharmony_ci      # the original algorithm.
986bf215546Sopenharmony_ci      new_opcodes = self.IndexMap()
987bf215546Sopenharmony_ci
988bf215546Sopenharmony_ci      # Process states on the global worklist, filtering them for each opcode,
989bf215546Sopenharmony_ci      # updating the filter tables, and updating the filtered worklists if any
990bf215546Sopenharmony_ci      # new filtered states are found. Similar to ComputeRepresenterSets() in
991bf215546Sopenharmony_ci      # the original algorithm, although that only processes a single state.
992bf215546Sopenharmony_ci      def process_new_states():
993bf215546Sopenharmony_ci         while self.worklist_index < len(self.states):
994bf215546Sopenharmony_ci            state = self.states[self.worklist_index]
995bf215546Sopenharmony_ci            # Calculate pattern matches for this state. Each pattern is
996bf215546Sopenharmony_ci            # assigned to a unique item, so we don't have to worry about
997bf215546Sopenharmony_ci            # deduplicating them here. However, we do have to sort them so
998bf215546Sopenharmony_ci            # that they're visited at runtime in the order they're specified
999bf215546Sopenharmony_ci            # in the source.
1000bf215546Sopenharmony_ci            patterns = list(sorted(p for item in state for p in item.patterns))
1001bf215546Sopenharmony_ci
1002bf215546Sopenharmony_ci            if patterns:
1003bf215546Sopenharmony_ci                # Add our patterns to the global table.
1004bf215546Sopenharmony_ci                self.state_pattern_offsets.append(len(self.state_patterns))
1005bf215546Sopenharmony_ci                self.state_patterns.extend(patterns)
1006bf215546Sopenharmony_ci                self.state_patterns.append(None)
1007bf215546Sopenharmony_ci            else:
1008bf215546Sopenharmony_ci                # Point to the initial sentinel in the global table.
1009bf215546Sopenharmony_ci                self.state_pattern_offsets.append(0)
1010bf215546Sopenharmony_ci
1011bf215546Sopenharmony_ci            # calculate filter table for this state, and update filtered
1012bf215546Sopenharmony_ci            # worklists.
1013bf215546Sopenharmony_ci            for op in self.opcodes:
1014bf215546Sopenharmony_ci               filt = self.filter[op]
1015bf215546Sopenharmony_ci               rep = self.rep[op]
1016bf215546Sopenharmony_ci               filtered = frozenset(item for item in state if \
1017bf215546Sopenharmony_ci                  op in item.parent_ops)
1018bf215546Sopenharmony_ci               if filtered in rep:
1019bf215546Sopenharmony_ci                  rep_index = rep.index(filtered)
1020bf215546Sopenharmony_ci               else:
1021bf215546Sopenharmony_ci                  rep_index = rep.add(filtered)
1022bf215546Sopenharmony_ci                  new_opcodes.add(op)
1023bf215546Sopenharmony_ci               assert len(filt) == self.worklist_index
1024bf215546Sopenharmony_ci               filt.append(rep_index)
1025bf215546Sopenharmony_ci            self.worklist_index += 1
1026bf215546Sopenharmony_ci
1027bf215546Sopenharmony_ci      # There are two start states: one which can only match as a wildcard,
1028bf215546Sopenharmony_ci      # and one which can match as a wildcard or constant. These will be the
1029bf215546Sopenharmony_ci      # states of intrinsics/other instructions and load_const instructions,
1030bf215546Sopenharmony_ci      # respectively. The indices of these must match the definitions of
1031bf215546Sopenharmony_ci      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
1032bf215546Sopenharmony_ci      # initialize things correctly.
1033bf215546Sopenharmony_ci      self.states.add(frozenset((self.wildcard,)))
1034bf215546Sopenharmony_ci      self.states.add(frozenset((self.const,self.wildcard)))
1035bf215546Sopenharmony_ci      process_new_states()
1036bf215546Sopenharmony_ci
1037bf215546Sopenharmony_ci      while len(new_opcodes) > 0:
1038bf215546Sopenharmony_ci         for op in new_opcodes:
1039bf215546Sopenharmony_ci            rep = self.rep[op]
1040bf215546Sopenharmony_ci            table = self.table[op]
1041bf215546Sopenharmony_ci            op_worklist_index = worklist_indices[op]
1042bf215546Sopenharmony_ci            if op in conv_opcode_types:
1043bf215546Sopenharmony_ci               num_srcs = 1
1044bf215546Sopenharmony_ci            else:
1045bf215546Sopenharmony_ci               num_srcs = opcodes[op].num_inputs
1046bf215546Sopenharmony_ci
1047bf215546Sopenharmony_ci            # Iterate over all possible source combinations where at least one
1048bf215546Sopenharmony_ci            # is on the worklist.
1049bf215546Sopenharmony_ci            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
1050bf215546Sopenharmony_ci               if all(src_idx < op_worklist_index for src_idx in src_indices):
1051bf215546Sopenharmony_ci                  continue
1052bf215546Sopenharmony_ci
1053bf215546Sopenharmony_ci               srcs = tuple(rep[src_idx] for src_idx in src_indices)
1054bf215546Sopenharmony_ci
1055bf215546Sopenharmony_ci               # Try all possible pairings of source items and add the
1056bf215546Sopenharmony_ci               # corresponding parent items. This is Comp_a from the paper.
1057bf215546Sopenharmony_ci               parent = set(self.items[op, item_srcs] for item_srcs in
1058bf215546Sopenharmony_ci                  itertools.product(*srcs) if (op, item_srcs) in self.items)
1059bf215546Sopenharmony_ci
1060bf215546Sopenharmony_ci               # We could always start matching something else with a
1061bf215546Sopenharmony_ci               # wildcard. This is Cl from the paper.
1062bf215546Sopenharmony_ci               parent.add(self.wildcard)
1063bf215546Sopenharmony_ci
1064bf215546Sopenharmony_ci               table[src_indices] = self.states.add(frozenset(parent))
1065bf215546Sopenharmony_ci            worklist_indices[op] = len(rep)
1066bf215546Sopenharmony_ci         new_opcodes.clear()
1067bf215546Sopenharmony_ci         process_new_states()
1068bf215546Sopenharmony_ci
# Mako template for the C source generated for one algebraic pass. It is
# rendered by AlgebraicPass.render(), which supplies pass_name, xforms,
# opcode_xforms, condition_list, automaton, expression_cond, variable_cond,
# get_c_opcode and itertools (opcode_xforms is passed in but not referenced
# by the template text).
#
# The generated file contains, in order:
#   * <pass>_values: the search and replace values of every transform. The
#     shared `cache` dict is handed to each render() call, and its
#     "next_index" entry is checked against the array size with a
#     STATIC_ASSERT in the pass function.
#   * <pass>_expression_cond / <pass>_variable_cond: emitted only when the
#     corresponding list is non-empty; the table refers to NULL otherwise.
#   * <pass>_transforms: one row per entry of automaton.state_patterns, with
#     every None entry written out as a { ~0, ~0, ~0 } sentinel row.
#   * <pass>_pass_op_table: for each automaton opcode, its filter array (NULL
#     when every filter entry is 0), the number of filtered states, and the
#     transition table flattened over all filtered-index combinations.
#   * <pass>_transform_offsets: automaton.state_pattern_offsets.
#   * <pass>_table and the pass entry point, which fills condition_flags from
#     condition_list and calls nir_algebraic_impl() on every function that
#     has an impl.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union ${pass_name}_values[] = {
% for xform in xforms:
   /* ${xform.search} => ${xform.replace} */
${xform.search.render(cache)}
${xform.replace.render(cache)}
% endfor
};

% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
   ${cond[0]},
% endfor
};
% endif

% if variable_cond:
static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
   ${cond[0]},
% endfor
};
% endif

static const struct transform ${pass_name}_transforms[] = {
% for i in automaton.state_patterns:
% if i is not None:
   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
% else:
   { ~0, ~0, ~0 }, /* Sentinel */

% endif
% endfor
};

static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
% if all(e == 0 for e in automaton.filter[op]):
      .filter = NULL,
% else:
      .filter = (const uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
% endif
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (const uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

/* Mapping from state index to offset in transforms (0 being no transforms) */
static const uint16_t ${pass_name}_transform_offsets[] = {
% for offset in automaton.state_pattern_offsets:
   ${offset},
% endfor
};

static const nir_algebraic_table ${pass_name}_table = {
   .transforms = ${pass_name}_transforms,
   .transform_offsets = ${pass_name}_transform_offsets,
   .pass_op_table = ${pass_name}_pass_op_table,
   .values = ${pass_name}_values,
   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
};

bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl) {
         progress |= nir_algebraic_impl(function->impl, condition_flags,
                                        &${pass_name}_table);
      }
   }

   return progress;
}
""")
1187bf215546Sopenharmony_ci
1188bf215546Sopenharmony_ci
class AlgebraicPass(object):
   """A named algebraic pass: its parsed transforms and a C source renderer.

   Attributes set by the constructor:
      pass_name        name handed to the template as pass_name.
      xforms           the SearchAndReplace objects, in input order.
      opcode_xforms    maps an opcode name to the transforms whose search
                       root uses it; for opcodes listed in conv_opcode_types
                       there is one entry per destination bit size.
      expression_cond  dict, empty on creation. This object is passed to
      variable_cond    SearchAndReplace, which presumably registers
                       conditions here (verify against SearchAndReplace).
                       render() emits the keys ordered by their values.
      automaton        TreeAutomaton built from xforms.
   """

   def __init__(self, pass_name, transforms):
      """Parse and validate `transforms`, then build the matching automaton.

      pass_name   -- name of the generated pass.
      transforms  -- iterable whose elements are either SearchAndReplace
                     objects or raw descriptions accepted by the
                     SearchAndReplace constructor.

      Every problem found is reported on stderr. If any transform failed to
      parse or violated the commutative-expression limit, the process exits
      with status 1 once all transforms have been checked.
      """
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name
      self.expression_cond = {}
      self.variable_cond = {}

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform, self)
            except Exception:
               # Report and keep going so that every broken transform is
               # listed in a single run. KeyboardInterrupt and SystemExit
               # are deliberately not intercepted.
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # Conversion opcodes are registered once per sized variant.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # The pattern is annotated as exceeding the limit but does
               # not, so the annotation is stale. No exception is active
               # here, hence no traceback is printed.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      # Stop before building the automaton: after an error self.xforms is
      # incomplete or invalid, and the result would be thrown away anyway.
      if error:
         sys.exit(1)

      self.automaton = TreeAutomaton(self.xforms)


   def render(self):
      """Return the generated C source for this pass as a string.

      The expression and variable condition dicts are handed to the template
      as lists of (key, value) pairs sorted by value.
      """
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)
1262bf215546Sopenharmony_ci
1263bf215546Sopenharmony_ci# The replacement expression isn't necessarily exact if the search expression is exact.
def ignore_exact(*expr):
   """Build a search expression from `expr` and flag it as ignoring exactness.

   The positional arguments are passed as one tuple to
   SearchExpression.create(); the resulting object gets its ignore_exact
   attribute set to True and is returned.
   """
   inexact = SearchExpression.create(expr)
   inexact.ignore_exact = True
   return inexact
1268