1bf215546Sopenharmony_ci# 2bf215546Sopenharmony_ci# Copyright (C) 2014 Intel Corporation 3bf215546Sopenharmony_ci# 4bf215546Sopenharmony_ci# Permission is hereby granted, free of charge, to any person obtaining a 5bf215546Sopenharmony_ci# copy of this software and associated documentation files (the "Software"), 6bf215546Sopenharmony_ci# to deal in the Software without restriction, including without limitation 7bf215546Sopenharmony_ci# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8bf215546Sopenharmony_ci# and/or sell copies of the Software, and to permit persons to whom the 9bf215546Sopenharmony_ci# Software is furnished to do so, subject to the following conditions: 10bf215546Sopenharmony_ci# 11bf215546Sopenharmony_ci# The above copyright notice and this permission notice (including the next 12bf215546Sopenharmony_ci# paragraph) shall be included in all copies or substantial portions of the 13bf215546Sopenharmony_ci# Software. 14bf215546Sopenharmony_ci# 15bf215546Sopenharmony_ci# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16bf215546Sopenharmony_ci# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17bf215546Sopenharmony_ci# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18bf215546Sopenharmony_ci# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19bf215546Sopenharmony_ci# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20bf215546Sopenharmony_ci# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21bf215546Sopenharmony_ci# IN THE SOFTWARE. 22bf215546Sopenharmony_ci# 23bf215546Sopenharmony_ci# Authors: 24bf215546Sopenharmony_ci# Jason Ekstrand (jason@jlekstrand.net) 25bf215546Sopenharmony_ci 26bf215546Sopenharmony_ciimport ast 27bf215546Sopenharmony_cifrom collections import defaultdict 28bf215546Sopenharmony_ciimport itertools 29bf215546Sopenharmony_ciimport struct 30bf215546Sopenharmony_ciimport sys 31bf215546Sopenharmony_ciimport mako.template 32bf215546Sopenharmony_ciimport re 33bf215546Sopenharmony_ciimport traceback 34bf215546Sopenharmony_ci 35bf215546Sopenharmony_cifrom nir_opcodes import opcodes, type_sizes 36bf215546Sopenharmony_ci 37bf215546Sopenharmony_ci# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 38bf215546Sopenharmony_cinir_search_max_comm_ops = 8 39bf215546Sopenharmony_ci 40bf215546Sopenharmony_ci# These opcodes are only employed by nir_search. This provides a mapping from 41bf215546Sopenharmony_ci# opcode to destination type. 42bf215546Sopenharmony_ciconv_opcode_types = { 43bf215546Sopenharmony_ci 'i2f' : 'float', 44bf215546Sopenharmony_ci 'u2f' : 'float', 45bf215546Sopenharmony_ci 'f2f' : 'float', 46bf215546Sopenharmony_ci 'f2u' : 'uint', 47bf215546Sopenharmony_ci 'f2i' : 'int', 48bf215546Sopenharmony_ci 'u2u' : 'uint', 49bf215546Sopenharmony_ci 'i2i' : 'int', 50bf215546Sopenharmony_ci 'b2f' : 'float', 51bf215546Sopenharmony_ci 'b2i' : 'int', 52bf215546Sopenharmony_ci 'i2b' : 'bool', 53bf215546Sopenharmony_ci 'f2b' : 'bool', 54bf215546Sopenharmony_ci} 55bf215546Sopenharmony_ci 56bf215546Sopenharmony_cidef get_cond_index(conds, cond): 57bf215546Sopenharmony_ci if cond: 58bf215546Sopenharmony_ci if cond in conds: 59bf215546Sopenharmony_ci return conds[cond] 60bf215546Sopenharmony_ci else: 61bf215546Sopenharmony_ci cond_index = len(conds) 62bf215546Sopenharmony_ci conds[cond] = cond_index 63bf215546Sopenharmony_ci return cond_index 64bf215546Sopenharmony_ci else: 65bf215546Sopenharmony_ci return -1 66bf215546Sopenharmony_ci 67bf215546Sopenharmony_cidef get_c_opcode(op): 68bf215546Sopenharmony_ci if op in conv_opcode_types: 69bf215546Sopenharmony_ci return 'nir_search_op_' + op 70bf215546Sopenharmony_ci else: 71bf215546Sopenharmony_ci return 'nir_op_' + op 72bf215546Sopenharmony_ci 73bf215546Sopenharmony_ci_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?") 74bf215546Sopenharmony_ci 75bf215546Sopenharmony_cidef type_bits(type_str): 76bf215546Sopenharmony_ci m = _type_re.match(type_str) 77bf215546Sopenharmony_ci assert m.group('type') 78bf215546Sopenharmony_ci 79bf215546Sopenharmony_ci if m.group('bits') is None: 80bf215546Sopenharmony_ci return 0 81bf215546Sopenharmony_ci else: 82bf215546Sopenharmony_ci return int(m.group('bits')) 83bf215546Sopenharmony_ci 84bf215546Sopenharmony_ci# Represents a set of variables, each with a unique id 85bf215546Sopenharmony_ciclass VarSet(object): 86bf215546Sopenharmony_ci def __init__(self): 87bf215546Sopenharmony_ci self.names = {} 88bf215546Sopenharmony_ci self.ids = itertools.count() 89bf215546Sopenharmony_ci self.immutable = False; 90bf215546Sopenharmony_ci 91bf215546Sopenharmony_ci def __getitem__(self, name): 92bf215546Sopenharmony_ci if name not in self.names: 93bf215546Sopenharmony_ci assert not self.immutable, "Unknown replacement variable: " + name 94bf215546Sopenharmony_ci self.names[name] = next(self.ids) 95bf215546Sopenharmony_ci 96bf215546Sopenharmony_ci return self.names[name] 97bf215546Sopenharmony_ci 98bf215546Sopenharmony_ci def lock(self): 99bf215546Sopenharmony_ci self.immutable = True 100bf215546Sopenharmony_ci 101bf215546Sopenharmony_ciclass SearchExpression(object): 102bf215546Sopenharmony_ci def __init__(self, expr): 103bf215546Sopenharmony_ci self.opcode = expr[0] 104bf215546Sopenharmony_ci self.sources = expr[1:] 105bf215546Sopenharmony_ci self.ignore_exact = False 106bf215546Sopenharmony_ci 107bf215546Sopenharmony_ci @staticmethod 108bf215546Sopenharmony_ci def create(val): 109bf215546Sopenharmony_ci if isinstance(val, tuple): 110bf215546Sopenharmony_ci return SearchExpression(val) 111bf215546Sopenharmony_ci else: 112bf215546Sopenharmony_ci assert(isinstance(val, SearchExpression)) 113bf215546Sopenharmony_ci return val 114bf215546Sopenharmony_ci 115bf215546Sopenharmony_ci def __repr__(self): 116bf215546Sopenharmony_ci l = [self.opcode, *self.sources] 117bf215546Sopenharmony_ci if self.ignore_exact: 118bf215546Sopenharmony_ci l.append('ignore_exact') 119bf215546Sopenharmony_ci return repr((*l,)) 120bf215546Sopenharmony_ci 121bf215546Sopenharmony_ciclass Value(object): 122bf215546Sopenharmony_ci @staticmethod 123bf215546Sopenharmony_ci def create(val, name_base, varset, algebraic_pass): 124bf215546Sopenharmony_ci if isinstance(val, bytes): 125bf215546Sopenharmony_ci val = val.decode('utf-8') 126bf215546Sopenharmony_ci 127bf215546Sopenharmony_ci if isinstance(val, tuple) or isinstance(val, SearchExpression): 128bf215546Sopenharmony_ci return Expression(val, name_base, varset, algebraic_pass) 129bf215546Sopenharmony_ci elif isinstance(val, Expression): 130bf215546Sopenharmony_ci return val 131bf215546Sopenharmony_ci elif isinstance(val, str): 132bf215546Sopenharmony_ci return Variable(val, name_base, varset, algebraic_pass) 133bf215546Sopenharmony_ci elif isinstance(val, (bool, float, int)): 134bf215546Sopenharmony_ci return Constant(val, name_base) 135bf215546Sopenharmony_ci 136bf215546Sopenharmony_ci def __init__(self, val, name, type_str): 137bf215546Sopenharmony_ci self.in_val = str(val) 138bf215546Sopenharmony_ci self.name = name 139bf215546Sopenharmony_ci self.type_str = type_str 140bf215546Sopenharmony_ci 141bf215546Sopenharmony_ci def __str__(self): 142bf215546Sopenharmony_ci return self.in_val 143bf215546Sopenharmony_ci 144bf215546Sopenharmony_ci def get_bit_size(self): 145bf215546Sopenharmony_ci """Get the physical bit-size that has been chosen for this value, or if 146bf215546Sopenharmony_ci there is none, the canonical value which currently represents this 147bf215546Sopenharmony_ci bit-size class. Variables will be preferred, i.e. if there are any 148bf215546Sopenharmony_ci variables in the equivalence class, the canonical value will be a 149bf215546Sopenharmony_ci variable. We do this since we'll need to know which variable each value 150bf215546Sopenharmony_ci is equivalent to when constructing the replacement expression. This is 151bf215546Sopenharmony_ci the "find" part of the union-find algorithm. 152bf215546Sopenharmony_ci """ 153bf215546Sopenharmony_ci bit_size = self 154bf215546Sopenharmony_ci 155bf215546Sopenharmony_ci while isinstance(bit_size, Value): 156bf215546Sopenharmony_ci if bit_size._bit_size is None: 157bf215546Sopenharmony_ci break 158bf215546Sopenharmony_ci bit_size = bit_size._bit_size 159bf215546Sopenharmony_ci 160bf215546Sopenharmony_ci if bit_size is not self: 161bf215546Sopenharmony_ci self._bit_size = bit_size 162bf215546Sopenharmony_ci return bit_size 163bf215546Sopenharmony_ci 164bf215546Sopenharmony_ci def set_bit_size(self, other): 165bf215546Sopenharmony_ci """Make self.get_bit_size() return what other.get_bit_size() return 166bf215546Sopenharmony_ci before calling this, or just "other" if it's a concrete bit-size. This is 167bf215546Sopenharmony_ci the "union" part of the union-find algorithm. 168bf215546Sopenharmony_ci """ 169bf215546Sopenharmony_ci 170bf215546Sopenharmony_ci self_bit_size = self.get_bit_size() 171bf215546Sopenharmony_ci other_bit_size = other if isinstance(other, int) else other.get_bit_size() 172bf215546Sopenharmony_ci 173bf215546Sopenharmony_ci if self_bit_size == other_bit_size: 174bf215546Sopenharmony_ci return 175bf215546Sopenharmony_ci 176bf215546Sopenharmony_ci self_bit_size._bit_size = other_bit_size 177bf215546Sopenharmony_ci 178bf215546Sopenharmony_ci @property 179bf215546Sopenharmony_ci def type_enum(self): 180bf215546Sopenharmony_ci return "nir_search_value_" + self.type_str 181bf215546Sopenharmony_ci 182bf215546Sopenharmony_ci @property 183bf215546Sopenharmony_ci def c_bit_size(self): 184bf215546Sopenharmony_ci bit_size = self.get_bit_size() 185bf215546Sopenharmony_ci if isinstance(bit_size, int): 186bf215546Sopenharmony_ci return bit_size 187bf215546Sopenharmony_ci elif isinstance(bit_size, Variable): 188bf215546Sopenharmony_ci return -bit_size.index - 1 189bf215546Sopenharmony_ci else: 190bf215546Sopenharmony_ci # If the bit-size class is neither a variable, nor an actual bit-size, then 191bf215546Sopenharmony_ci # - If it's in the search expression, we don't need to check anything 192bf215546Sopenharmony_ci # - If it's in the replace expression, either it's ambiguous (in which 193bf215546Sopenharmony_ci # case we'd reject it), or it equals the bit-size of the search value 194bf215546Sopenharmony_ci # We represent these cases with a 0 bit-size. 195bf215546Sopenharmony_ci return 0 196bf215546Sopenharmony_ci 197bf215546Sopenharmony_ci __template = mako.template.Template(""" { .${val.type_str} = { 198bf215546Sopenharmony_ci { ${val.type_enum}, ${val.c_bit_size} }, 199bf215546Sopenharmony_ci% if isinstance(val, Constant): 200bf215546Sopenharmony_ci ${val.type()}, { ${val.hex()} /* ${val.value} */ }, 201bf215546Sopenharmony_ci% elif isinstance(val, Variable): 202bf215546Sopenharmony_ci ${val.index}, /* ${val.var_name} */ 203bf215546Sopenharmony_ci ${'true' if val.is_constant else 'false'}, 204bf215546Sopenharmony_ci ${val.type() or 'nir_type_invalid' }, 205bf215546Sopenharmony_ci ${val.cond_index}, 206bf215546Sopenharmony_ci ${val.swizzle()}, 207bf215546Sopenharmony_ci% elif isinstance(val, Expression): 208bf215546Sopenharmony_ci ${'true' if val.inexact else 'false'}, 209bf215546Sopenharmony_ci ${'true' if val.exact else 'false'}, 210bf215546Sopenharmony_ci ${'true' if val.ignore_exact else 'false'}, 211bf215546Sopenharmony_ci ${val.c_opcode()}, 212bf215546Sopenharmony_ci ${val.comm_expr_idx}, ${val.comm_exprs}, 213bf215546Sopenharmony_ci { ${', '.join(src.array_index for src in val.sources)} }, 214bf215546Sopenharmony_ci ${val.cond_index}, 215bf215546Sopenharmony_ci% endif 216bf215546Sopenharmony_ci } }, 217bf215546Sopenharmony_ci""") 218bf215546Sopenharmony_ci 219bf215546Sopenharmony_ci def render(self, cache): 220bf215546Sopenharmony_ci struct_init = self.__template.render(val=self, 221bf215546Sopenharmony_ci Constant=Constant, 222bf215546Sopenharmony_ci Variable=Variable, 223bf215546Sopenharmony_ci Expression=Expression) 224bf215546Sopenharmony_ci if struct_init in cache: 225bf215546Sopenharmony_ci # If it's in the cache, register a name remap in the cache and render 226bf215546Sopenharmony_ci # only a comment saying it's been remapped 227bf215546Sopenharmony_ci self.array_index = cache[struct_init] 228bf215546Sopenharmony_ci return " /* {} -> {} in the cache */\n".format(self.name, 229bf215546Sopenharmony_ci cache[struct_init]) 230bf215546Sopenharmony_ci else: 231bf215546Sopenharmony_ci self.array_index = str(cache["next_index"]) 232bf215546Sopenharmony_ci cache[struct_init] = self.array_index 233bf215546Sopenharmony_ci cache["next_index"] += 1 234bf215546Sopenharmony_ci return struct_init 235bf215546Sopenharmony_ci 236bf215546Sopenharmony_ci_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?") 237bf215546Sopenharmony_ci 238bf215546Sopenharmony_ciclass Constant(Value): 239bf215546Sopenharmony_ci def __init__(self, val, name): 240bf215546Sopenharmony_ci Value.__init__(self, val, name, "constant") 241bf215546Sopenharmony_ci 242bf215546Sopenharmony_ci if isinstance(val, (str)): 243bf215546Sopenharmony_ci m = _constant_re.match(val) 244bf215546Sopenharmony_ci self.value = ast.literal_eval(m.group('value')) 245bf215546Sopenharmony_ci self._bit_size = int(m.group('bits')) if m.group('bits') else None 246bf215546Sopenharmony_ci else: 247bf215546Sopenharmony_ci self.value = val 248bf215546Sopenharmony_ci self._bit_size = None 249bf215546Sopenharmony_ci 250bf215546Sopenharmony_ci if isinstance(self.value, bool): 251bf215546Sopenharmony_ci assert self._bit_size is None or self._bit_size == 1 252bf215546Sopenharmony_ci self._bit_size = 1 253bf215546Sopenharmony_ci 254bf215546Sopenharmony_ci def hex(self): 255bf215546Sopenharmony_ci if isinstance(self.value, (bool)): 256bf215546Sopenharmony_ci return 'NIR_TRUE' if self.value else 'NIR_FALSE' 257bf215546Sopenharmony_ci if isinstance(self.value, int): 258bf215546Sopenharmony_ci return hex(self.value) 259bf215546Sopenharmony_ci elif isinstance(self.value, float): 260bf215546Sopenharmony_ci return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) 261bf215546Sopenharmony_ci else: 262bf215546Sopenharmony_ci assert False 263bf215546Sopenharmony_ci 264bf215546Sopenharmony_ci def type(self): 265bf215546Sopenharmony_ci if isinstance(self.value, (bool)): 266bf215546Sopenharmony_ci return "nir_type_bool" 267bf215546Sopenharmony_ci elif isinstance(self.value, int): 268bf215546Sopenharmony_ci return "nir_type_int" 269bf215546Sopenharmony_ci elif isinstance(self.value, float): 270bf215546Sopenharmony_ci return "nir_type_float" 271bf215546Sopenharmony_ci 272bf215546Sopenharmony_ci def equivalent(self, other): 273bf215546Sopenharmony_ci """Check that two constants are equivalent. 274bf215546Sopenharmony_ci 275bf215546Sopenharmony_ci This is check is much weaker than equality. One generally cannot be 276bf215546Sopenharmony_ci used in place of the other. Using this implementation for the __eq__ 277bf215546Sopenharmony_ci will break BitSizeValidator. 278bf215546Sopenharmony_ci 279bf215546Sopenharmony_ci """ 280bf215546Sopenharmony_ci if not isinstance(other, type(self)): 281bf215546Sopenharmony_ci return False 282bf215546Sopenharmony_ci 283bf215546Sopenharmony_ci return self.value == other.value 284bf215546Sopenharmony_ci 285bf215546Sopenharmony_ci# The $ at the end forces there to be an error if any part of the string 286bf215546Sopenharmony_ci# doesn't match one of the field patterns. 287bf215546Sopenharmony_ci_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" 288bf215546Sopenharmony_ci r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" 289bf215546Sopenharmony_ci r"(?P<cond>\([^\)]+\))?" 290bf215546Sopenharmony_ci r"(?P<swiz>\.[xyzw]+)?" 291bf215546Sopenharmony_ci r"$") 292bf215546Sopenharmony_ci 293bf215546Sopenharmony_ciclass Variable(Value): 294bf215546Sopenharmony_ci def __init__(self, val, name, varset, algebraic_pass): 295bf215546Sopenharmony_ci Value.__init__(self, val, name, "variable") 296bf215546Sopenharmony_ci 297bf215546Sopenharmony_ci m = _var_name_re.match(val) 298bf215546Sopenharmony_ci assert m and m.group('name') is not None, \ 299bf215546Sopenharmony_ci "Malformed variable name \"{}\".".format(val) 300bf215546Sopenharmony_ci 301bf215546Sopenharmony_ci self.var_name = m.group('name') 302bf215546Sopenharmony_ci 303bf215546Sopenharmony_ci # Prevent common cases where someone puts quotes around a literal 304bf215546Sopenharmony_ci # constant. If we want to support names that have numeric or 305bf215546Sopenharmony_ci # punctuation characters, we can me the first assertion more flexible. 306bf215546Sopenharmony_ci assert self.var_name.isalpha() 307bf215546Sopenharmony_ci assert self.var_name != 'True' 308bf215546Sopenharmony_ci assert self.var_name != 'False' 309bf215546Sopenharmony_ci 310bf215546Sopenharmony_ci self.is_constant = m.group('const') is not None 311bf215546Sopenharmony_ci self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond')) 312bf215546Sopenharmony_ci self.required_type = m.group('type') 313bf215546Sopenharmony_ci self._bit_size = int(m.group('bits')) if m.group('bits') else None 314bf215546Sopenharmony_ci self.swiz = m.group('swiz') 315bf215546Sopenharmony_ci 316bf215546Sopenharmony_ci if self.required_type == 'bool': 317bf215546Sopenharmony_ci if self._bit_size is not None: 318bf215546Sopenharmony_ci assert self._bit_size in type_sizes(self.required_type) 319bf215546Sopenharmony_ci else: 320bf215546Sopenharmony_ci self._bit_size = 1 321bf215546Sopenharmony_ci 322bf215546Sopenharmony_ci if self.required_type is not None: 323bf215546Sopenharmony_ci assert self.required_type in ('float', 'bool', 'int', 'uint') 324bf215546Sopenharmony_ci 325bf215546Sopenharmony_ci self.index = varset[self.var_name] 326bf215546Sopenharmony_ci 327bf215546Sopenharmony_ci def type(self): 328bf215546Sopenharmony_ci if self.required_type == 'bool': 329bf215546Sopenharmony_ci return "nir_type_bool" 330bf215546Sopenharmony_ci elif self.required_type in ('int', 'uint'): 331bf215546Sopenharmony_ci return "nir_type_int" 332bf215546Sopenharmony_ci elif self.required_type == 'float': 333bf215546Sopenharmony_ci return "nir_type_float" 334bf215546Sopenharmony_ci 335bf215546Sopenharmony_ci def equivalent(self, other): 336bf215546Sopenharmony_ci """Check that two variables are equivalent. 337bf215546Sopenharmony_ci 338bf215546Sopenharmony_ci This is check is much weaker than equality. One generally cannot be 339bf215546Sopenharmony_ci used in place of the other. Using this implementation for the __eq__ 340bf215546Sopenharmony_ci will break BitSizeValidator. 341bf215546Sopenharmony_ci 342bf215546Sopenharmony_ci """ 343bf215546Sopenharmony_ci if not isinstance(other, type(self)): 344bf215546Sopenharmony_ci return False 345bf215546Sopenharmony_ci 346bf215546Sopenharmony_ci return self.index == other.index 347bf215546Sopenharmony_ci 348bf215546Sopenharmony_ci def swizzle(self): 349bf215546Sopenharmony_ci if self.swiz is not None: 350bf215546Sopenharmony_ci swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3, 351bf215546Sopenharmony_ci 'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3, 352bf215546Sopenharmony_ci 'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7, 353bf215546Sopenharmony_ci 'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11, 354bf215546Sopenharmony_ci 'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 } 355bf215546Sopenharmony_ci return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}' 356bf215546Sopenharmony_ci return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}' 357bf215546Sopenharmony_ci 358bf215546Sopenharmony_ci_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" 359bf215546Sopenharmony_ci r"(?P<cond>\([^\)]+\))?") 360bf215546Sopenharmony_ci 361bf215546Sopenharmony_ciclass Expression(Value): 362bf215546Sopenharmony_ci def __init__(self, expr, name_base, varset, algebraic_pass): 363bf215546Sopenharmony_ci Value.__init__(self, expr, name_base, "expression") 364bf215546Sopenharmony_ci 365bf215546Sopenharmony_ci expr = SearchExpression.create(expr) 366bf215546Sopenharmony_ci 367bf215546Sopenharmony_ci m = _opcode_re.match(expr.opcode) 368bf215546Sopenharmony_ci assert m and m.group('opcode') is not None 369bf215546Sopenharmony_ci 370bf215546Sopenharmony_ci self.opcode = m.group('opcode') 371bf215546Sopenharmony_ci self._bit_size = int(m.group('bits')) if m.group('bits') else None 372bf215546Sopenharmony_ci self.inexact = m.group('inexact') is not None 373bf215546Sopenharmony_ci self.exact = m.group('exact') is not None 374bf215546Sopenharmony_ci self.ignore_exact = expr.ignore_exact 375bf215546Sopenharmony_ci self.cond = m.group('cond') 376bf215546Sopenharmony_ci 377bf215546Sopenharmony_ci assert not self.inexact or not self.exact, \ 378bf215546Sopenharmony_ci 'Expression cannot be both exact and inexact.' 379bf215546Sopenharmony_ci 380bf215546Sopenharmony_ci # "many-comm-expr" isn't really a condition. It's notification to the 381bf215546Sopenharmony_ci # generator that this pattern is known to have too many commutative 382bf215546Sopenharmony_ci # expressions, and an error should not be generated for this case. 383bf215546Sopenharmony_ci self.many_commutative_expressions = False 384bf215546Sopenharmony_ci if self.cond and self.cond.find("many-comm-expr") >= 0: 385bf215546Sopenharmony_ci # Split the condition into a comma-separated list. Remove 386bf215546Sopenharmony_ci # "many-comm-expr". If there is anything left, put it back together. 387bf215546Sopenharmony_ci c = self.cond[1:-1].split(",") 388bf215546Sopenharmony_ci c.remove("many-comm-expr") 389bf215546Sopenharmony_ci assert(len(c) <= 1) 390bf215546Sopenharmony_ci 391bf215546Sopenharmony_ci self.cond = c[0] if c else None 392bf215546Sopenharmony_ci self.many_commutative_expressions = True 393bf215546Sopenharmony_ci 394bf215546Sopenharmony_ci # Deduplicate references to the condition functions for the expressions 395bf215546Sopenharmony_ci # and save the index for the order they were added. 396bf215546Sopenharmony_ci self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond) 397bf215546Sopenharmony_ci 398bf215546Sopenharmony_ci self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass) 399bf215546Sopenharmony_ci for (i, src) in enumerate(expr.sources) ] 400bf215546Sopenharmony_ci 401bf215546Sopenharmony_ci # nir_search_expression::srcs is hard-coded to 4 402bf215546Sopenharmony_ci assert len(self.sources) <= 4 403bf215546Sopenharmony_ci 404bf215546Sopenharmony_ci if self.opcode in conv_opcode_types: 405bf215546Sopenharmony_ci assert self._bit_size is None, \ 406bf215546Sopenharmony_ci 'Expression cannot use an unsized conversion opcode with ' \ 407bf215546Sopenharmony_ci 'an explicit size; that\'s silly.' 408bf215546Sopenharmony_ci 409bf215546Sopenharmony_ci self.__index_comm_exprs(0) 410bf215546Sopenharmony_ci 411bf215546Sopenharmony_ci def equivalent(self, other): 412bf215546Sopenharmony_ci """Check that two variables are equivalent. 413bf215546Sopenharmony_ci 414bf215546Sopenharmony_ci This is check is much weaker than equality. One generally cannot be 415bf215546Sopenharmony_ci used in place of the other. Using this implementation for the __eq__ 416bf215546Sopenharmony_ci will break BitSizeValidator. 417bf215546Sopenharmony_ci 418bf215546Sopenharmony_ci This implementation does not check for equivalence due to commutativity, 419bf215546Sopenharmony_ci but it could. 420bf215546Sopenharmony_ci 421bf215546Sopenharmony_ci """ 422bf215546Sopenharmony_ci if not isinstance(other, type(self)): 423bf215546Sopenharmony_ci return False 424bf215546Sopenharmony_ci 425bf215546Sopenharmony_ci if len(self.sources) != len(other.sources): 426bf215546Sopenharmony_ci return False 427bf215546Sopenharmony_ci 428bf215546Sopenharmony_ci if self.opcode != other.opcode: 429bf215546Sopenharmony_ci return False 430bf215546Sopenharmony_ci 431bf215546Sopenharmony_ci return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) 432bf215546Sopenharmony_ci 433bf215546Sopenharmony_ci def __index_comm_exprs(self, base_idx): 434bf215546Sopenharmony_ci """Recursively count and index commutative expressions 435bf215546Sopenharmony_ci """ 436bf215546Sopenharmony_ci self.comm_exprs = 0 437bf215546Sopenharmony_ci 438bf215546Sopenharmony_ci # A note about the explicit "len(self.sources)" check. The list of 439bf215546Sopenharmony_ci # sources comes from user input, and that input might be bad. Check 440bf215546Sopenharmony_ci # that the expected second source exists before accessing it. Without 441bf215546Sopenharmony_ci # this check, a unit test that does "('iadd', 'a')" will crash. 442bf215546Sopenharmony_ci if self.opcode not in conv_opcode_types and \ 443bf215546Sopenharmony_ci "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ 444bf215546Sopenharmony_ci len(self.sources) >= 2 and \ 445bf215546Sopenharmony_ci not self.sources[0].equivalent(self.sources[1]): 446bf215546Sopenharmony_ci self.comm_expr_idx = base_idx 447bf215546Sopenharmony_ci self.comm_exprs += 1 448bf215546Sopenharmony_ci else: 449bf215546Sopenharmony_ci self.comm_expr_idx = -1 450bf215546Sopenharmony_ci 451bf215546Sopenharmony_ci for s in self.sources: 452bf215546Sopenharmony_ci if isinstance(s, Expression): 453bf215546Sopenharmony_ci s.__index_comm_exprs(base_idx + self.comm_exprs) 454bf215546Sopenharmony_ci self.comm_exprs += s.comm_exprs 455bf215546Sopenharmony_ci 456bf215546Sopenharmony_ci return self.comm_exprs 457bf215546Sopenharmony_ci 458bf215546Sopenharmony_ci def c_opcode(self): 459bf215546Sopenharmony_ci return get_c_opcode(self.opcode) 460bf215546Sopenharmony_ci 461bf215546Sopenharmony_ci def render(self, cache): 462bf215546Sopenharmony_ci srcs = "".join(src.render(cache) for src in self.sources) 463bf215546Sopenharmony_ci return srcs + super(Expression, self).render(cache) 464bf215546Sopenharmony_ci 465bf215546Sopenharmony_ciclass BitSizeValidator(object): 466bf215546Sopenharmony_ci """A class for validating bit sizes of expressions. 467bf215546Sopenharmony_ci 468bf215546Sopenharmony_ci NIR supports multiple bit-sizes on expressions in order to handle things 469bf215546Sopenharmony_ci such as fp64. The source and destination of every ALU operation is 470bf215546Sopenharmony_ci assigned a type and that type may or may not specify a bit size. Sources 471bf215546Sopenharmony_ci and destinations whose type does not specify a bit size are considered 472bf215546Sopenharmony_ci "unsized" and automatically take on the bit size of the corresponding 473bf215546Sopenharmony_ci register or SSA value. NIR has two simple rules for bit sizes that are 474bf215546Sopenharmony_ci validated by nir_validator: 475bf215546Sopenharmony_ci 476bf215546Sopenharmony_ci 1) A given SSA def or register has a single bit size that is respected by 477bf215546Sopenharmony_ci everything that reads from it or writes to it. 478bf215546Sopenharmony_ci 479bf215546Sopenharmony_ci 2) The bit sizes of all unsized inputs/outputs on any given ALU 480bf215546Sopenharmony_ci instruction must match. They need not match the sized inputs or 481bf215546Sopenharmony_ci outputs but they must match each other. 482bf215546Sopenharmony_ci 483bf215546Sopenharmony_ci In order to keep nir_algebraic relatively simple and easy-to-use, 484bf215546Sopenharmony_ci nir_search supports a type of bit-size inference based on the two rules 485bf215546Sopenharmony_ci above. This is similar to type inference in many common programming 486bf215546Sopenharmony_ci languages. If, for instance, you are constructing an add operation and you 487bf215546Sopenharmony_ci know the second source is 16-bit, then you know that the other source and 488bf215546Sopenharmony_ci the destination must also be 16-bit. There are, however, cases where this 489bf215546Sopenharmony_ci inference can be ambiguous or contradictory. Consider, for instance, the 490bf215546Sopenharmony_ci following transformation: 491bf215546Sopenharmony_ci 492bf215546Sopenharmony_ci (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 493bf215546Sopenharmony_ci 494bf215546Sopenharmony_ci This transformation can potentially cause a problem because usub_borrow is 495bf215546Sopenharmony_ci well-defined for any bit-size of integer. However, b2i always generates a 496bf215546Sopenharmony_ci 32-bit result so it could end up replacing a 64-bit expression with one 497bf215546Sopenharmony_ci that takes two 64-bit values and produces a 32-bit value. As another 498bf215546Sopenharmony_ci example, consider this expression: 499bf215546Sopenharmony_ci 500bf215546Sopenharmony_ci (('bcsel', a, b, 0), ('iand', a, b)) 501bf215546Sopenharmony_ci 502bf215546Sopenharmony_ci In this case, in the search expression a must be 32-bit but b can 503bf215546Sopenharmony_ci potentially have any bit size. If we had a 64-bit b value, we would end up 504bf215546Sopenharmony_ci trying to and a 32-bit value with a 64-bit value which would be invalid 505bf215546Sopenharmony_ci 506bf215546Sopenharmony_ci This class solves that problem by providing a validation layer that proves 507bf215546Sopenharmony_ci that a given search-and-replace operation is 100% well-defined before we 508bf215546Sopenharmony_ci generate any code. This ensures that bugs are caught at compile time 509bf215546Sopenharmony_ci rather than at run time. 510bf215546Sopenharmony_ci 511bf215546Sopenharmony_ci Each value maintains a "bit-size class", which is either an actual bit size 512bf215546Sopenharmony_ci or an equivalence class with other values that must have the same bit size. 513bf215546Sopenharmony_ci The validator works by combining bit-size classes with each other according 514bf215546Sopenharmony_ci to the NIR rules outlined above, checking that there are no inconsistencies. 515bf215546Sopenharmony_ci When doing this for the replacement expression, we make sure to never change 516bf215546Sopenharmony_ci the equivalence class of any of the search values. We could make the example 517bf215546Sopenharmony_ci transforms above work by doing some extra run-time checking of the search 518bf215546Sopenharmony_ci expression, but we make the user specify those constraints themselves, to 519bf215546Sopenharmony_ci avoid any surprises. Since the replacement bitsizes can only be connected to 520bf215546Sopenharmony_ci the source bitsize via variables (variables must have the same bitsize in 521bf215546Sopenharmony_ci the source and replacment expressions) or the roots of the expression (the 522bf215546Sopenharmony_ci replacement expression must produce the same bit size as the search 523bf215546Sopenharmony_ci expression), we prevent merging a variable with anything when processing the 524bf215546Sopenharmony_ci replacement expression, or specializing the search bitsize 525bf215546Sopenharmony_ci with anything. The former prevents 526bf215546Sopenharmony_ci 527bf215546Sopenharmony_ci (('bcsel', a, b, 0), ('iand', a, b)) 528bf215546Sopenharmony_ci 529bf215546Sopenharmony_ci from being allowed, since we'd have to merge the bitsizes for a and b due to 530bf215546Sopenharmony_ci the 'iand', while the latter prevents 531bf215546Sopenharmony_ci 532bf215546Sopenharmony_ci (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 533bf215546Sopenharmony_ci 534bf215546Sopenharmony_ci from being allowed, since the search expression has the bit size of a and b, 535bf215546Sopenharmony_ci which can't be specialized to 32 which is the bitsize of the replace 536bf215546Sopenharmony_ci expression. It also prevents something like: 537bf215546Sopenharmony_ci 538bf215546Sopenharmony_ci (('b2i', ('i2b', a)), ('ineq', a, 0)) 539bf215546Sopenharmony_ci 540bf215546Sopenharmony_ci since the bitsize of 'b2i', which can be anything, can't be specialized to 541bf215546Sopenharmony_ci the bitsize of a. 542bf215546Sopenharmony_ci 543bf215546Sopenharmony_ci After doing all this, we check that every subexpression of the replacement 544bf215546Sopenharmony_ci was assigned a constant bitsize, the bitsize of a variable, or the bitsize 545bf215546Sopenharmony_ci of the search expresssion, since those are the things that are known when 546bf215546Sopenharmony_ci constructing the replacement expresssion. Finally, we record the bitsize 547bf215546Sopenharmony_ci needed in nir_search_value so that we know what to do when building the 548bf215546Sopenharmony_ci replacement expression. 549bf215546Sopenharmony_ci """ 550bf215546Sopenharmony_ci 551bf215546Sopenharmony_ci def __init__(self, varset): 552bf215546Sopenharmony_ci self._var_classes = [None] * len(varset.names) 553bf215546Sopenharmony_ci 554bf215546Sopenharmony_ci def compare_bitsizes(self, a, b): 555bf215546Sopenharmony_ci """Determines which bitsize class is a specialization of the other, or 556bf215546Sopenharmony_ci whether neither is. When we merge two different bitsizes, the 557bf215546Sopenharmony_ci less-specialized bitsize always points to the more-specialized one, so 558bf215546Sopenharmony_ci that calling get_bit_size() always gets you the most specialized bitsize. 559bf215546Sopenharmony_ci The specialization partial order is given by: 560bf215546Sopenharmony_ci - Physical bitsizes are always the most specialized, and a different 561bf215546Sopenharmony_ci bitsize can never specialize another. 562bf215546Sopenharmony_ci - In the search expression, variables can always be specialized to each 563bf215546Sopenharmony_ci other and to physical bitsizes. In the replace expression, we disallow 564bf215546Sopenharmony_ci this to avoid adding extra constraints to the search expression that 565bf215546Sopenharmony_ci the user didn't specify. 566bf215546Sopenharmony_ci - Expressions and constants without a bitsize can always be specialized to 567bf215546Sopenharmony_ci each other and variables, but not the other way around. 568bf215546Sopenharmony_ci 569bf215546Sopenharmony_ci We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, 570bf215546Sopenharmony_ci and None if they are not comparable (neither a <= b nor b <= a). 571bf215546Sopenharmony_ci """ 572bf215546Sopenharmony_ci if isinstance(a, int): 573bf215546Sopenharmony_ci if isinstance(b, int): 574bf215546Sopenharmony_ci return 0 if a == b else None 575bf215546Sopenharmony_ci elif isinstance(b, Variable): 576bf215546Sopenharmony_ci return -1 if self.is_search else None 577bf215546Sopenharmony_ci else: 578bf215546Sopenharmony_ci return -1 579bf215546Sopenharmony_ci elif isinstance(a, Variable): 580bf215546Sopenharmony_ci if isinstance(b, int): 581bf215546Sopenharmony_ci return 1 if self.is_search else None 582bf215546Sopenharmony_ci elif isinstance(b, Variable): 583bf215546Sopenharmony_ci return 0 if self.is_search or a.index == b.index else None 584bf215546Sopenharmony_ci else: 585bf215546Sopenharmony_ci return -1 586bf215546Sopenharmony_ci else: 587bf215546Sopenharmony_ci if isinstance(b, int): 588bf215546Sopenharmony_ci return 1 589bf215546Sopenharmony_ci elif isinstance(b, Variable): 590bf215546Sopenharmony_ci return 1 591bf215546Sopenharmony_ci else: 592bf215546Sopenharmony_ci return 0 593bf215546Sopenharmony_ci 594bf215546Sopenharmony_ci def unify_bit_size(self, a, b, error_msg): 595bf215546Sopenharmony_ci """Record that a must have the same bit-size as b. If both 596bf215546Sopenharmony_ci have been assigned conflicting physical bit-sizes, call "error_msg" with 597bf215546Sopenharmony_ci the bit-sizes of self and other to get a message and raise an error. 598bf215546Sopenharmony_ci In the replace expression, disallow merging variables with other 599bf215546Sopenharmony_ci variables and physical bit-sizes as well. 600bf215546Sopenharmony_ci """ 601bf215546Sopenharmony_ci a_bit_size = a.get_bit_size() 602bf215546Sopenharmony_ci b_bit_size = b if isinstance(b, int) else b.get_bit_size() 603bf215546Sopenharmony_ci 604bf215546Sopenharmony_ci cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) 605bf215546Sopenharmony_ci 606bf215546Sopenharmony_ci assert cmp_result is not None, \ 607bf215546Sopenharmony_ci error_msg(a_bit_size, b_bit_size) 608bf215546Sopenharmony_ci 609bf215546Sopenharmony_ci if cmp_result < 0: 610bf215546Sopenharmony_ci b_bit_size.set_bit_size(a) 611bf215546Sopenharmony_ci elif not isinstance(a_bit_size, int): 612bf215546Sopenharmony_ci a_bit_size.set_bit_size(b) 613bf215546Sopenharmony_ci 614bf215546Sopenharmony_ci def merge_variables(self, val): 615bf215546Sopenharmony_ci """Perform the first part of type inference by merging all the different 616bf215546Sopenharmony_ci uses of the same variable. We always do this as if we're in the search 617bf215546Sopenharmony_ci expression, even if we're actually not, since otherwise we'd get errors 618bf215546Sopenharmony_ci if the search expression specified some constraint but the replace 619bf215546Sopenharmony_ci expression didn't, because we'd be merging a variable and a constant. 620bf215546Sopenharmony_ci """ 621bf215546Sopenharmony_ci if isinstance(val, Variable): 622bf215546Sopenharmony_ci if self._var_classes[val.index] is None: 623bf215546Sopenharmony_ci self._var_classes[val.index] = val 624bf215546Sopenharmony_ci else: 625bf215546Sopenharmony_ci other = self._var_classes[val.index] 626bf215546Sopenharmony_ci self.unify_bit_size(other, val, 627bf215546Sopenharmony_ci lambda other_bit_size, bit_size: 628bf215546Sopenharmony_ci 'Variable {} has conflicting bit size requirements: ' \ 629bf215546Sopenharmony_ci 'it must have bit size {} and {}'.format( 630bf215546Sopenharmony_ci val.var_name, other_bit_size, bit_size)) 631bf215546Sopenharmony_ci elif isinstance(val, Expression): 632bf215546Sopenharmony_ci for src in val.sources: 633bf215546Sopenharmony_ci self.merge_variables(src) 634bf215546Sopenharmony_ci 635bf215546Sopenharmony_ci def validate_value(self, val): 636bf215546Sopenharmony_ci """Validate the an expression by performing classic Hindley-Milner 637bf215546Sopenharmony_ci type inference on bitsizes. This will detect if there are any conflicting 638bf215546Sopenharmony_ci requirements, and unify variables so that we know which variables must 639bf215546Sopenharmony_ci have the same bitsize. If we're operating on the replace expression, we 640bf215546Sopenharmony_ci will refuse to merge different variables together or merge a variable 641bf215546Sopenharmony_ci with a constant, in order to prevent surprises due to rules unexpectedly 642bf215546Sopenharmony_ci not matching at runtime. 643bf215546Sopenharmony_ci """ 644bf215546Sopenharmony_ci if not isinstance(val, Expression): 645bf215546Sopenharmony_ci return 646bf215546Sopenharmony_ci 647bf215546Sopenharmony_ci # Generic conversion ops are special in that they have a single unsized 648bf215546Sopenharmony_ci # source and an unsized destination and the two don't have to match. 649bf215546Sopenharmony_ci # This means there's no validation or unioning to do here besides the 650bf215546Sopenharmony_ci # len(val.sources) check. 651bf215546Sopenharmony_ci if val.opcode in conv_opcode_types: 652bf215546Sopenharmony_ci assert len(val.sources) == 1, \ 653bf215546Sopenharmony_ci "Expression {} has {} sources, expected 1".format( 654bf215546Sopenharmony_ci val, len(val.sources)) 655bf215546Sopenharmony_ci self.validate_value(val.sources[0]) 656bf215546Sopenharmony_ci return 657bf215546Sopenharmony_ci 658bf215546Sopenharmony_ci nir_op = opcodes[val.opcode] 659bf215546Sopenharmony_ci assert len(val.sources) == nir_op.num_inputs, \ 660bf215546Sopenharmony_ci "Expression {} has {} sources, expected {}".format( 661bf215546Sopenharmony_ci val, len(val.sources), nir_op.num_inputs) 662bf215546Sopenharmony_ci 663bf215546Sopenharmony_ci for src in val.sources: 664bf215546Sopenharmony_ci self.validate_value(src) 665bf215546Sopenharmony_ci 666bf215546Sopenharmony_ci dst_type_bits = type_bits(nir_op.output_type) 667bf215546Sopenharmony_ci 668bf215546Sopenharmony_ci # First, unify all the sources. That way, an error coming up because two 669bf215546Sopenharmony_ci # sources have an incompatible bit-size won't produce an error message 670bf215546Sopenharmony_ci # involving the destination. 671bf215546Sopenharmony_ci first_unsized_src = None 672bf215546Sopenharmony_ci for src_type, src in zip(nir_op.input_types, val.sources): 673bf215546Sopenharmony_ci src_type_bits = type_bits(src_type) 674bf215546Sopenharmony_ci if src_type_bits == 0: 675bf215546Sopenharmony_ci if first_unsized_src is None: 676bf215546Sopenharmony_ci first_unsized_src = src 677bf215546Sopenharmony_ci continue 678bf215546Sopenharmony_ci 679bf215546Sopenharmony_ci if self.is_search: 680bf215546Sopenharmony_ci self.unify_bit_size(first_unsized_src, src, 681bf215546Sopenharmony_ci lambda first_unsized_src_bit_size, src_bit_size: 682bf215546Sopenharmony_ci 'Source {} of {} must have bit size {}, while source {} ' \ 683bf215546Sopenharmony_ci 'must have incompatible bit size {}'.format( 684bf215546Sopenharmony_ci first_unsized_src, val, first_unsized_src_bit_size, 685bf215546Sopenharmony_ci src, src_bit_size)) 686bf215546Sopenharmony_ci else: 687bf215546Sopenharmony_ci self.unify_bit_size(first_unsized_src, src, 688bf215546Sopenharmony_ci lambda first_unsized_src_bit_size, src_bit_size: 689bf215546Sopenharmony_ci 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ 690bf215546Sopenharmony_ci 'of {} may not have the same bit size when building the ' \ 691bf215546Sopenharmony_ci 'replacement expression.'.format( 692bf215546Sopenharmony_ci first_unsized_src, first_unsized_src_bit_size, src, 693bf215546Sopenharmony_ci src_bit_size, val)) 694bf215546Sopenharmony_ci else: 695bf215546Sopenharmony_ci if self.is_search: 696bf215546Sopenharmony_ci self.unify_bit_size(src, src_type_bits, 697bf215546Sopenharmony_ci lambda src_bit_size, unused: 698bf215546Sopenharmony_ci '{} must have {} bits, but as a source of nir_op_{} '\ 699bf215546Sopenharmony_ci 'it must have {} bits'.format( 700bf215546Sopenharmony_ci src, src_bit_size, nir_op.name, src_type_bits)) 701bf215546Sopenharmony_ci else: 702bf215546Sopenharmony_ci self.unify_bit_size(src, src_type_bits, 703bf215546Sopenharmony_ci lambda src_bit_size, unused: 704bf215546Sopenharmony_ci '{} has the bit size of {}, but as a source of ' \ 705bf215546Sopenharmony_ci 'nir_op_{} it must have {} bits, which may not be the ' \ 706bf215546Sopenharmony_ci 'same'.format( 707bf215546Sopenharmony_ci src, src_bit_size, nir_op.name, src_type_bits)) 708bf215546Sopenharmony_ci 709bf215546Sopenharmony_ci if dst_type_bits == 0: 710bf215546Sopenharmony_ci if first_unsized_src is not None: 711bf215546Sopenharmony_ci if self.is_search: 712bf215546Sopenharmony_ci self.unify_bit_size(val, first_unsized_src, 713bf215546Sopenharmony_ci lambda val_bit_size, src_bit_size: 714bf215546Sopenharmony_ci '{} must have the bit size of {}, while its source {} ' \ 715bf215546Sopenharmony_ci 'must have incompatible bit size {}'.format( 716bf215546Sopenharmony_ci val, val_bit_size, first_unsized_src, src_bit_size)) 717bf215546Sopenharmony_ci else: 718bf215546Sopenharmony_ci self.unify_bit_size(val, first_unsized_src, 719bf215546Sopenharmony_ci lambda val_bit_size, src_bit_size: 720bf215546Sopenharmony_ci '{} must have {} bits, but its source {} ' \ 721bf215546Sopenharmony_ci '(bit size of {}) may not have that bit size ' \ 722bf215546Sopenharmony_ci 'when building the replacement.'.format( 723bf215546Sopenharmony_ci val, val_bit_size, first_unsized_src, src_bit_size)) 724bf215546Sopenharmony_ci else: 725bf215546Sopenharmony_ci self.unify_bit_size(val, dst_type_bits, 726bf215546Sopenharmony_ci lambda dst_bit_size, unused: 727bf215546Sopenharmony_ci '{} must have {} bits, but as a destination of nir_op_{} ' \ 728bf215546Sopenharmony_ci 'it must have {} bits'.format( 729bf215546Sopenharmony_ci val, dst_bit_size, nir_op.name, dst_type_bits)) 730bf215546Sopenharmony_ci 731bf215546Sopenharmony_ci def validate_replace(self, val, search): 732bf215546Sopenharmony_ci bit_size = val.get_bit_size() 733bf215546Sopenharmony_ci assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ 734bf215546Sopenharmony_ci bit_size == search.get_bit_size(), \ 735bf215546Sopenharmony_ci 'Ambiguous bit size for replacement value {}: ' \ 736bf215546Sopenharmony_ci 'it cannot be deduced from a variable, a fixed bit size ' \ 737bf215546Sopenharmony_ci 'somewhere, or the search expression.'.format(val) 738bf215546Sopenharmony_ci 739bf215546Sopenharmony_ci if isinstance(val, Expression): 740bf215546Sopenharmony_ci for src in val.sources: 741bf215546Sopenharmony_ci self.validate_replace(src, search) 742bf215546Sopenharmony_ci 743bf215546Sopenharmony_ci def validate(self, search, replace): 744bf215546Sopenharmony_ci self.is_search = True 745bf215546Sopenharmony_ci self.merge_variables(search) 746bf215546Sopenharmony_ci self.merge_variables(replace) 747bf215546Sopenharmony_ci self.validate_value(search) 748bf215546Sopenharmony_ci 749bf215546Sopenharmony_ci self.is_search = False 750bf215546Sopenharmony_ci self.validate_value(replace) 751bf215546Sopenharmony_ci 752bf215546Sopenharmony_ci # Check that search is always more specialized than replace. Note that 753bf215546Sopenharmony_ci # we're doing this in replace mode, disallowing merging variables. 754bf215546Sopenharmony_ci search_bit_size = search.get_bit_size() 755bf215546Sopenharmony_ci replace_bit_size = replace.get_bit_size() 756bf215546Sopenharmony_ci cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) 757bf215546Sopenharmony_ci 758bf215546Sopenharmony_ci assert cmp_result is not None and cmp_result <= 0, \ 759bf215546Sopenharmony_ci 'The search expression bit size {} and replace expression ' \ 760bf215546Sopenharmony_ci 'bit size {} may not be the same'.format( 761bf215546Sopenharmony_ci search_bit_size, replace_bit_size) 762bf215546Sopenharmony_ci 763bf215546Sopenharmony_ci replace.set_bit_size(search) 764bf215546Sopenharmony_ci 765bf215546Sopenharmony_ci self.validate_replace(replace, search) 766bf215546Sopenharmony_ci 767bf215546Sopenharmony_ci_optimization_ids = itertools.count() 768bf215546Sopenharmony_ci 769bf215546Sopenharmony_cicondition_list = ['true'] 770bf215546Sopenharmony_ci 771bf215546Sopenharmony_ciclass SearchAndReplace(object): 772bf215546Sopenharmony_ci def __init__(self, transform, algebraic_pass): 773bf215546Sopenharmony_ci self.id = next(_optimization_ids) 774bf215546Sopenharmony_ci 775bf215546Sopenharmony_ci search = transform[0] 776bf215546Sopenharmony_ci replace = transform[1] 777bf215546Sopenharmony_ci if len(transform) > 2: 778bf215546Sopenharmony_ci self.condition = transform[2] 779bf215546Sopenharmony_ci else: 780bf215546Sopenharmony_ci self.condition = 'true' 781bf215546Sopenharmony_ci 782bf215546Sopenharmony_ci if self.condition not in condition_list: 783bf215546Sopenharmony_ci condition_list.append(self.condition) 784bf215546Sopenharmony_ci self.condition_index = condition_list.index(self.condition) 785bf215546Sopenharmony_ci 786bf215546Sopenharmony_ci varset = VarSet() 787bf215546Sopenharmony_ci if isinstance(search, Expression): 788bf215546Sopenharmony_ci self.search = search 789bf215546Sopenharmony_ci else: 790bf215546Sopenharmony_ci self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass) 791bf215546Sopenharmony_ci 792bf215546Sopenharmony_ci varset.lock() 793bf215546Sopenharmony_ci 794bf215546Sopenharmony_ci if isinstance(replace, Value): 795bf215546Sopenharmony_ci self.replace = replace 796bf215546Sopenharmony_ci else: 797bf215546Sopenharmony_ci self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass) 798bf215546Sopenharmony_ci 799bf215546Sopenharmony_ci BitSizeValidator(varset).validate(self.search, self.replace) 800bf215546Sopenharmony_ci 801bf215546Sopenharmony_ciclass TreeAutomaton(object): 802bf215546Sopenharmony_ci """This class calculates a bottom-up tree automaton to quickly search for 803bf215546Sopenharmony_ci the left-hand sides of tranforms. Tree automatons are a generalization of 804bf215546Sopenharmony_ci classical NFA's and DFA's, where the transition function determines the 805bf215546Sopenharmony_ci state of the parent node based on the state of its children. We construct a 806bf215546Sopenharmony_ci deterministic automaton to match patterns, using a similar algorithm to the 807bf215546Sopenharmony_ci classical NFA to DFA construction. At the moment, it only matches opcodes 808bf215546Sopenharmony_ci and constants (without checking the actual value), leaving more detailed 809bf215546Sopenharmony_ci checking to the search function which actually checks the leaves. The 810bf215546Sopenharmony_ci automaton acts as a quick filter for the search function, requiring only n 811bf215546Sopenharmony_ci + 1 table lookups for each n-source operation. The implementation is based 812bf215546Sopenharmony_ci on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." 813bf215546Sopenharmony_ci In the language of that reference, this is a frontier-to-root deterministic 814bf215546Sopenharmony_ci automaton using only symbol filtering. The filtering is crucial to reduce 815bf215546Sopenharmony_ci both the time taken to generate the tables and the size of the tables. 816bf215546Sopenharmony_ci """ 817bf215546Sopenharmony_ci def __init__(self, transforms): 818bf215546Sopenharmony_ci self.patterns = [t.search for t in transforms] 819bf215546Sopenharmony_ci self._compute_items() 820bf215546Sopenharmony_ci self._build_table() 821bf215546Sopenharmony_ci #print('num items: {}'.format(len(set(self.items.values())))) 822bf215546Sopenharmony_ci #print('num states: {}'.format(len(self.states))) 823bf215546Sopenharmony_ci #for state, patterns in zip(self.states, self.patterns): 824bf215546Sopenharmony_ci # print('{}: num patterns: {}'.format(state, len(patterns))) 825bf215546Sopenharmony_ci 826bf215546Sopenharmony_ci class IndexMap(object): 827bf215546Sopenharmony_ci """An indexed list of objects, where one can either lookup an object by 828bf215546Sopenharmony_ci index or find the index associated to an object quickly using a hash 829bf215546Sopenharmony_ci table. Compared to a list, it has a constant time index(). Compared to a 830bf215546Sopenharmony_ci set, it provides a stable iteration order. 831bf215546Sopenharmony_ci """ 832bf215546Sopenharmony_ci def __init__(self, iterable=()): 833bf215546Sopenharmony_ci self.objects = [] 834bf215546Sopenharmony_ci self.map = {} 835bf215546Sopenharmony_ci for obj in iterable: 836bf215546Sopenharmony_ci self.add(obj) 837bf215546Sopenharmony_ci 838bf215546Sopenharmony_ci def __getitem__(self, i): 839bf215546Sopenharmony_ci return self.objects[i] 840bf215546Sopenharmony_ci 841bf215546Sopenharmony_ci def __contains__(self, obj): 842bf215546Sopenharmony_ci return obj in self.map 843bf215546Sopenharmony_ci 844bf215546Sopenharmony_ci def __len__(self): 845bf215546Sopenharmony_ci return len(self.objects) 846bf215546Sopenharmony_ci 847bf215546Sopenharmony_ci def __iter__(self): 848bf215546Sopenharmony_ci return iter(self.objects) 849bf215546Sopenharmony_ci 850bf215546Sopenharmony_ci def clear(self): 851bf215546Sopenharmony_ci self.objects = [] 852bf215546Sopenharmony_ci self.map.clear() 853bf215546Sopenharmony_ci 854bf215546Sopenharmony_ci def index(self, obj): 855bf215546Sopenharmony_ci return self.map[obj] 856bf215546Sopenharmony_ci 857bf215546Sopenharmony_ci def add(self, obj): 858bf215546Sopenharmony_ci if obj in self.map: 859bf215546Sopenharmony_ci return self.map[obj] 860bf215546Sopenharmony_ci else: 861bf215546Sopenharmony_ci index = len(self.objects) 862bf215546Sopenharmony_ci self.objects.append(obj) 863bf215546Sopenharmony_ci self.map[obj] = index 864bf215546Sopenharmony_ci return index 865bf215546Sopenharmony_ci 866bf215546Sopenharmony_ci def __repr__(self): 867bf215546Sopenharmony_ci return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 868bf215546Sopenharmony_ci 869bf215546Sopenharmony_ci class Item(object): 870bf215546Sopenharmony_ci """This represents an "item" in the language of "Tree Automatons." This 871bf215546Sopenharmony_ci is just a subtree of some pattern, which represents a potential partial 872bf215546Sopenharmony_ci match at runtime. We deduplicate them, so that identical subtrees of 873bf215546Sopenharmony_ci different patterns share the same object, and store some extra 874bf215546Sopenharmony_ci information needed for the main algorithm as well. 875bf215546Sopenharmony_ci """ 876bf215546Sopenharmony_ci def __init__(self, opcode, children): 877bf215546Sopenharmony_ci self.opcode = opcode 878bf215546Sopenharmony_ci self.children = children 879bf215546Sopenharmony_ci # These are the indices of patterns for which this item is the root node. 880bf215546Sopenharmony_ci self.patterns = [] 881bf215546Sopenharmony_ci # This the set of opcodes for parents of this item. Used to speed up 882bf215546Sopenharmony_ci # filtering. 883bf215546Sopenharmony_ci self.parent_ops = set() 884bf215546Sopenharmony_ci 885bf215546Sopenharmony_ci def __str__(self): 886bf215546Sopenharmony_ci return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 887bf215546Sopenharmony_ci 888bf215546Sopenharmony_ci def __repr__(self): 889bf215546Sopenharmony_ci return str(self) 890bf215546Sopenharmony_ci 891bf215546Sopenharmony_ci def _compute_items(self): 892bf215546Sopenharmony_ci """Build a set of all possible items, deduplicating them.""" 893bf215546Sopenharmony_ci # This is a map from (opcode, sources) to item. 894bf215546Sopenharmony_ci self.items = {} 895bf215546Sopenharmony_ci 896bf215546Sopenharmony_ci # The set of all opcodes used by the patterns. Used later to avoid 897bf215546Sopenharmony_ci # building and emitting all the tables for opcodes that aren't used. 898bf215546Sopenharmony_ci self.opcodes = self.IndexMap() 899bf215546Sopenharmony_ci 900bf215546Sopenharmony_ci def get_item(opcode, children, pattern=None): 901bf215546Sopenharmony_ci commutative = len(children) >= 2 \ 902bf215546Sopenharmony_ci and "2src_commutative" in opcodes[opcode].algebraic_properties 903bf215546Sopenharmony_ci item = self.items.setdefault((opcode, children), 904bf215546Sopenharmony_ci self.Item(opcode, children)) 905bf215546Sopenharmony_ci if commutative: 906bf215546Sopenharmony_ci self.items[opcode, (children[1], children[0]) + children[2:]] = item 907bf215546Sopenharmony_ci if pattern is not None: 908bf215546Sopenharmony_ci item.patterns.append(pattern) 909bf215546Sopenharmony_ci return item 910bf215546Sopenharmony_ci 911bf215546Sopenharmony_ci self.wildcard = get_item("__wildcard", ()) 912bf215546Sopenharmony_ci self.const = get_item("__const", ()) 913bf215546Sopenharmony_ci 914bf215546Sopenharmony_ci def process_subpattern(src, pattern=None): 915bf215546Sopenharmony_ci if isinstance(src, Constant): 916bf215546Sopenharmony_ci # Note: we throw away the actual constant value! 917bf215546Sopenharmony_ci return self.const 918bf215546Sopenharmony_ci elif isinstance(src, Variable): 919bf215546Sopenharmony_ci if src.is_constant: 920bf215546Sopenharmony_ci return self.const 921bf215546Sopenharmony_ci else: 922bf215546Sopenharmony_ci # Note: we throw away which variable it is here! This special 923bf215546Sopenharmony_ci # item is equivalent to nu in "Tree Automatons." 924bf215546Sopenharmony_ci return self.wildcard 925bf215546Sopenharmony_ci else: 926bf215546Sopenharmony_ci assert isinstance(src, Expression) 927bf215546Sopenharmony_ci opcode = src.opcode 928bf215546Sopenharmony_ci stripped = opcode.rstrip('0123456789') 929bf215546Sopenharmony_ci if stripped in conv_opcode_types: 930bf215546Sopenharmony_ci # Matches that use conversion opcodes with a specific type, 931bf215546Sopenharmony_ci # like f2b1, are tricky. Either we construct the automaton to 932bf215546Sopenharmony_ci # match specific NIR opcodes like nir_op_f2b1, in which case we 933bf215546Sopenharmony_ci # need to create separate items for each possible NIR opcode 934bf215546Sopenharmony_ci # for patterns that have a generic opcode like f2b, or we 935bf215546Sopenharmony_ci # construct it to match the search opcode, in which case we 936bf215546Sopenharmony_ci # need to map f2b1 to f2b when constructing the automaton. Here 937bf215546Sopenharmony_ci # we do the latter. 938bf215546Sopenharmony_ci opcode = stripped 939bf215546Sopenharmony_ci self.opcodes.add(opcode) 940bf215546Sopenharmony_ci children = tuple(process_subpattern(c) for c in src.sources) 941bf215546Sopenharmony_ci item = get_item(opcode, children, pattern) 942bf215546Sopenharmony_ci for i, child in enumerate(children): 943bf215546Sopenharmony_ci child.parent_ops.add(opcode) 944bf215546Sopenharmony_ci return item 945bf215546Sopenharmony_ci 946bf215546Sopenharmony_ci for i, pattern in enumerate(self.patterns): 947bf215546Sopenharmony_ci process_subpattern(pattern, i) 948bf215546Sopenharmony_ci 949bf215546Sopenharmony_ci def _build_table(self): 950bf215546Sopenharmony_ci """This is the core algorithm which builds up the transition table. It 951bf215546Sopenharmony_ci is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 952bf215546Sopenharmony_ci Comp_a and Filt_{a,i} using integers to identify match sets." It 953bf215546Sopenharmony_ci simultaneously builds up a list of all possible "match sets" or 954bf215546Sopenharmony_ci "states", where each match set represents the set of Item's that match a 955bf215546Sopenharmony_ci given instruction, and builds up the transition table between states. 956bf215546Sopenharmony_ci """ 957bf215546Sopenharmony_ci # Map from opcode + filtered state indices to transitioned state. 958bf215546Sopenharmony_ci self.table = defaultdict(dict) 959bf215546Sopenharmony_ci # Bijection from state to index. q in the original algorithm is 960bf215546Sopenharmony_ci # len(self.states) 961bf215546Sopenharmony_ci self.states = self.IndexMap() 962bf215546Sopenharmony_ci # Lists of pattern matches separated by None 963bf215546Sopenharmony_ci self.state_patterns = [None] 964bf215546Sopenharmony_ci # Offset in the ->transforms table for each state index 965bf215546Sopenharmony_ci self.state_pattern_offsets = [] 966bf215546Sopenharmony_ci # Map from state index to filtered state index for each opcode. 967bf215546Sopenharmony_ci self.filter = defaultdict(list) 968bf215546Sopenharmony_ci # Bijections from filtered state to filtered state index for each 969bf215546Sopenharmony_ci # opcode, called the "representor sets" in the original algorithm. 970bf215546Sopenharmony_ci # q_{a,j} in the original algorithm is len(self.rep[op]). 971bf215546Sopenharmony_ci self.rep = defaultdict(self.IndexMap) 972bf215546Sopenharmony_ci 973bf215546Sopenharmony_ci # Everything in self.states with a index at least worklist_index is part 974bf215546Sopenharmony_ci # of the worklist of newly created states. There is also a worklist of 975bf215546Sopenharmony_ci # newly fitered states for each opcode, for which worklist_indices 976bf215546Sopenharmony_ci # serves a similar purpose. worklist_index corresponds to p in the 977bf215546Sopenharmony_ci # original algorithm, while worklist_indices is p_{a,j} (although since 978bf215546Sopenharmony_ci # we only filter by opcode/symbol, it's really just p_a). 979bf215546Sopenharmony_ci self.worklist_index = 0 980bf215546Sopenharmony_ci worklist_indices = defaultdict(lambda: 0) 981bf215546Sopenharmony_ci 982bf215546Sopenharmony_ci # This is the set of opcodes for which the filtered worklist is non-empty. 983bf215546Sopenharmony_ci # It's used to avoid scanning opcodes for which there is nothing to 984bf215546Sopenharmony_ci # process when building the transition table. It corresponds to new_a in 985bf215546Sopenharmony_ci # the original algorithm. 986bf215546Sopenharmony_ci new_opcodes = self.IndexMap() 987bf215546Sopenharmony_ci 988bf215546Sopenharmony_ci # Process states on the global worklist, filtering them for each opcode, 989bf215546Sopenharmony_ci # updating the filter tables, and updating the filtered worklists if any 990bf215546Sopenharmony_ci # new filtered states are found. Similar to ComputeRepresenterSets() in 991bf215546Sopenharmony_ci # the original algorithm, although that only processes a single state. 992bf215546Sopenharmony_ci def process_new_states(): 993bf215546Sopenharmony_ci while self.worklist_index < len(self.states): 994bf215546Sopenharmony_ci state = self.states[self.worklist_index] 995bf215546Sopenharmony_ci # Calculate pattern matches for this state. Each pattern is 996bf215546Sopenharmony_ci # assigned to a unique item, so we don't have to worry about 997bf215546Sopenharmony_ci # deduplicating them here. However, we do have to sort them so 998bf215546Sopenharmony_ci # that they're visited at runtime in the order they're specified 999bf215546Sopenharmony_ci # in the source. 1000bf215546Sopenharmony_ci patterns = list(sorted(p for item in state for p in item.patterns)) 1001bf215546Sopenharmony_ci 1002bf215546Sopenharmony_ci if patterns: 1003bf215546Sopenharmony_ci # Add our patterns to the global table. 1004bf215546Sopenharmony_ci self.state_pattern_offsets.append(len(self.state_patterns)) 1005bf215546Sopenharmony_ci self.state_patterns.extend(patterns) 1006bf215546Sopenharmony_ci self.state_patterns.append(None) 1007bf215546Sopenharmony_ci else: 1008bf215546Sopenharmony_ci # Point to the initial sentinel in the global table. 1009bf215546Sopenharmony_ci self.state_pattern_offsets.append(0) 1010bf215546Sopenharmony_ci 1011bf215546Sopenharmony_ci # calculate filter table for this state, and update filtered 1012bf215546Sopenharmony_ci # worklists. 1013bf215546Sopenharmony_ci for op in self.opcodes: 1014bf215546Sopenharmony_ci filt = self.filter[op] 1015bf215546Sopenharmony_ci rep = self.rep[op] 1016bf215546Sopenharmony_ci filtered = frozenset(item for item in state if \ 1017bf215546Sopenharmony_ci op in item.parent_ops) 1018bf215546Sopenharmony_ci if filtered in rep: 1019bf215546Sopenharmony_ci rep_index = rep.index(filtered) 1020bf215546Sopenharmony_ci else: 1021bf215546Sopenharmony_ci rep_index = rep.add(filtered) 1022bf215546Sopenharmony_ci new_opcodes.add(op) 1023bf215546Sopenharmony_ci assert len(filt) == self.worklist_index 1024bf215546Sopenharmony_ci filt.append(rep_index) 1025bf215546Sopenharmony_ci self.worklist_index += 1 1026bf215546Sopenharmony_ci 1027bf215546Sopenharmony_ci # There are two start states: one which can only match as a wildcard, 1028bf215546Sopenharmony_ci # and one which can match as a wildcard or constant. These will be the 1029bf215546Sopenharmony_ci # states of intrinsics/other instructions and load_const instructions, 1030bf215546Sopenharmony_ci # respectively. The indices of these must match the definitions of 1031bf215546Sopenharmony_ci # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 1032bf215546Sopenharmony_ci # initialize things correctly. 1033bf215546Sopenharmony_ci self.states.add(frozenset((self.wildcard,))) 1034bf215546Sopenharmony_ci self.states.add(frozenset((self.const,self.wildcard))) 1035bf215546Sopenharmony_ci process_new_states() 1036bf215546Sopenharmony_ci 1037bf215546Sopenharmony_ci while len(new_opcodes) > 0: 1038bf215546Sopenharmony_ci for op in new_opcodes: 1039bf215546Sopenharmony_ci rep = self.rep[op] 1040bf215546Sopenharmony_ci table = self.table[op] 1041bf215546Sopenharmony_ci op_worklist_index = worklist_indices[op] 1042bf215546Sopenharmony_ci if op in conv_opcode_types: 1043bf215546Sopenharmony_ci num_srcs = 1 1044bf215546Sopenharmony_ci else: 1045bf215546Sopenharmony_ci num_srcs = opcodes[op].num_inputs 1046bf215546Sopenharmony_ci 1047bf215546Sopenharmony_ci # Iterate over all possible source combinations where at least one 1048bf215546Sopenharmony_ci # is on the worklist. 1049bf215546Sopenharmony_ci for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1050bf215546Sopenharmony_ci if all(src_idx < op_worklist_index for src_idx in src_indices): 1051bf215546Sopenharmony_ci continue 1052bf215546Sopenharmony_ci 1053bf215546Sopenharmony_ci srcs = tuple(rep[src_idx] for src_idx in src_indices) 1054bf215546Sopenharmony_ci 1055bf215546Sopenharmony_ci # Try all possible pairings of source items and add the 1056bf215546Sopenharmony_ci # corresponding parent items. This is Comp_a from the paper. 1057bf215546Sopenharmony_ci parent = set(self.items[op, item_srcs] for item_srcs in 1058bf215546Sopenharmony_ci itertools.product(*srcs) if (op, item_srcs) in self.items) 1059bf215546Sopenharmony_ci 1060bf215546Sopenharmony_ci # We could always start matching something else with a 1061bf215546Sopenharmony_ci # wildcard. This is Cl from the paper. 1062bf215546Sopenharmony_ci parent.add(self.wildcard) 1063bf215546Sopenharmony_ci 1064bf215546Sopenharmony_ci table[src_indices] = self.states.add(frozenset(parent)) 1065bf215546Sopenharmony_ci worklist_indices[op] = len(rep) 1066bf215546Sopenharmony_ci new_opcodes.clear() 1067bf215546Sopenharmony_ci process_new_states() 1068bf215546Sopenharmony_ci 1069bf215546Sopenharmony_ci_algebraic_pass_template = mako.template.Template(""" 1070bf215546Sopenharmony_ci#include "nir.h" 1071bf215546Sopenharmony_ci#include "nir_builder.h" 1072bf215546Sopenharmony_ci#include "nir_search.h" 1073bf215546Sopenharmony_ci#include "nir_search_helpers.h" 1074bf215546Sopenharmony_ci 1075bf215546Sopenharmony_ci/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1076bf215546Sopenharmony_ci * transforms: 1077bf215546Sopenharmony_ci% for xform in xforms: 1078bf215546Sopenharmony_ci * ${xform.search} => ${xform.replace} 1079bf215546Sopenharmony_ci% endfor 1080bf215546Sopenharmony_ci */ 1081bf215546Sopenharmony_ci 1082bf215546Sopenharmony_ci<% cache = {"next_index": 0} %> 1083bf215546Sopenharmony_cistatic const nir_search_value_union ${pass_name}_values[] = { 1084bf215546Sopenharmony_ci% for xform in xforms: 1085bf215546Sopenharmony_ci /* ${xform.search} => ${xform.replace} */ 1086bf215546Sopenharmony_ci${xform.search.render(cache)} 1087bf215546Sopenharmony_ci${xform.replace.render(cache)} 1088bf215546Sopenharmony_ci% endfor 1089bf215546Sopenharmony_ci}; 1090bf215546Sopenharmony_ci 1091bf215546Sopenharmony_ci% if expression_cond: 1092bf215546Sopenharmony_cistatic const nir_search_expression_cond ${pass_name}_expression_cond[] = { 1093bf215546Sopenharmony_ci% for cond in expression_cond: 1094bf215546Sopenharmony_ci ${cond[0]}, 1095bf215546Sopenharmony_ci% endfor 1096bf215546Sopenharmony_ci}; 1097bf215546Sopenharmony_ci% endif 1098bf215546Sopenharmony_ci 1099bf215546Sopenharmony_ci% if variable_cond: 1100bf215546Sopenharmony_cistatic const nir_search_variable_cond ${pass_name}_variable_cond[] = { 1101bf215546Sopenharmony_ci% for cond in variable_cond: 1102bf215546Sopenharmony_ci ${cond[0]}, 1103bf215546Sopenharmony_ci% endfor 1104bf215546Sopenharmony_ci}; 1105bf215546Sopenharmony_ci% endif 1106bf215546Sopenharmony_ci 1107bf215546Sopenharmony_cistatic const struct transform ${pass_name}_transforms[] = { 1108bf215546Sopenharmony_ci% for i in automaton.state_patterns: 1109bf215546Sopenharmony_ci% if i is not None: 1110bf215546Sopenharmony_ci { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} }, 1111bf215546Sopenharmony_ci% else: 1112bf215546Sopenharmony_ci { ~0, ~0, ~0 }, /* Sentinel */ 1113bf215546Sopenharmony_ci 1114bf215546Sopenharmony_ci% endif 1115bf215546Sopenharmony_ci% endfor 1116bf215546Sopenharmony_ci}; 1117bf215546Sopenharmony_ci 1118bf215546Sopenharmony_cistatic const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = { 1119bf215546Sopenharmony_ci% for op in automaton.opcodes: 1120bf215546Sopenharmony_ci [${get_c_opcode(op)}] = { 1121bf215546Sopenharmony_ci% if all(e == 0 for e in automaton.filter[op]): 1122bf215546Sopenharmony_ci .filter = NULL, 1123bf215546Sopenharmony_ci% else: 1124bf215546Sopenharmony_ci .filter = (const uint16_t []) { 1125bf215546Sopenharmony_ci % for e in automaton.filter[op]: 1126bf215546Sopenharmony_ci ${e}, 1127bf215546Sopenharmony_ci % endfor 1128bf215546Sopenharmony_ci }, 1129bf215546Sopenharmony_ci% endif 1130bf215546Sopenharmony_ci <% 1131bf215546Sopenharmony_ci num_filtered = len(automaton.rep[op]) 1132bf215546Sopenharmony_ci %> 1133bf215546Sopenharmony_ci .num_filtered_states = ${num_filtered}, 1134bf215546Sopenharmony_ci .table = (const uint16_t []) { 1135bf215546Sopenharmony_ci <% 1136bf215546Sopenharmony_ci num_srcs = len(next(iter(automaton.table[op]))) 1137bf215546Sopenharmony_ci %> 1138bf215546Sopenharmony_ci % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1139bf215546Sopenharmony_ci ${automaton.table[op][indices]}, 1140bf215546Sopenharmony_ci % endfor 1141bf215546Sopenharmony_ci }, 1142bf215546Sopenharmony_ci }, 1143bf215546Sopenharmony_ci% endfor 1144bf215546Sopenharmony_ci}; 1145bf215546Sopenharmony_ci 1146bf215546Sopenharmony_ci/* Mapping from state index to offset in transforms (0 being no transforms) */ 1147bf215546Sopenharmony_cistatic const uint16_t ${pass_name}_transform_offsets[] = { 1148bf215546Sopenharmony_ci% for offset in automaton.state_pattern_offsets: 1149bf215546Sopenharmony_ci ${offset}, 1150bf215546Sopenharmony_ci% endfor 1151bf215546Sopenharmony_ci}; 1152bf215546Sopenharmony_ci 1153bf215546Sopenharmony_cistatic const nir_algebraic_table ${pass_name}_table = { 1154bf215546Sopenharmony_ci .transforms = ${pass_name}_transforms, 1155bf215546Sopenharmony_ci .transform_offsets = ${pass_name}_transform_offsets, 1156bf215546Sopenharmony_ci .pass_op_table = ${pass_name}_pass_op_table, 1157bf215546Sopenharmony_ci .values = ${pass_name}_values, 1158bf215546Sopenharmony_ci .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, 1159bf215546Sopenharmony_ci .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, 1160bf215546Sopenharmony_ci}; 1161bf215546Sopenharmony_ci 1162bf215546Sopenharmony_cibool 1163bf215546Sopenharmony_ci${pass_name}(nir_shader *shader) 1164bf215546Sopenharmony_ci{ 1165bf215546Sopenharmony_ci bool progress = false; 1166bf215546Sopenharmony_ci bool condition_flags[${len(condition_list)}]; 1167bf215546Sopenharmony_ci const nir_shader_compiler_options *options = shader->options; 1168bf215546Sopenharmony_ci const shader_info *info = &shader->info; 1169bf215546Sopenharmony_ci (void) options; 1170bf215546Sopenharmony_ci (void) info; 1171bf215546Sopenharmony_ci 1172bf215546Sopenharmony_ci STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values)); 1173bf215546Sopenharmony_ci % for index, condition in enumerate(condition_list): 1174bf215546Sopenharmony_ci condition_flags[${index}] = ${condition}; 1175bf215546Sopenharmony_ci % endfor 1176bf215546Sopenharmony_ci 1177bf215546Sopenharmony_ci nir_foreach_function(function, shader) { 1178bf215546Sopenharmony_ci if (function->impl) { 1179bf215546Sopenharmony_ci progress |= nir_algebraic_impl(function->impl, condition_flags, 1180bf215546Sopenharmony_ci &${pass_name}_table); 1181bf215546Sopenharmony_ci } 1182bf215546Sopenharmony_ci } 1183bf215546Sopenharmony_ci 1184bf215546Sopenharmony_ci return progress; 1185bf215546Sopenharmony_ci} 1186bf215546Sopenharmony_ci""") 1187bf215546Sopenharmony_ci 1188bf215546Sopenharmony_ci 1189bf215546Sopenharmony_ciclass AlgebraicPass(object): 1190bf215546Sopenharmony_ci def __init__(self, pass_name, transforms): 1191bf215546Sopenharmony_ci self.xforms = [] 1192bf215546Sopenharmony_ci self.opcode_xforms = defaultdict(lambda : []) 1193bf215546Sopenharmony_ci self.pass_name = pass_name 1194bf215546Sopenharmony_ci self.expression_cond = {} 1195bf215546Sopenharmony_ci self.variable_cond = {} 1196bf215546Sopenharmony_ci 1197bf215546Sopenharmony_ci error = False 1198bf215546Sopenharmony_ci 1199bf215546Sopenharmony_ci for xform in transforms: 1200bf215546Sopenharmony_ci if not isinstance(xform, SearchAndReplace): 1201bf215546Sopenharmony_ci try: 1202bf215546Sopenharmony_ci xform = SearchAndReplace(xform, self) 1203bf215546Sopenharmony_ci except: 1204bf215546Sopenharmony_ci print("Failed to parse transformation:", file=sys.stderr) 1205bf215546Sopenharmony_ci print(" " + str(xform), file=sys.stderr) 1206bf215546Sopenharmony_ci traceback.print_exc(file=sys.stderr) 1207bf215546Sopenharmony_ci print('', file=sys.stderr) 1208bf215546Sopenharmony_ci error = True 1209bf215546Sopenharmony_ci continue 1210bf215546Sopenharmony_ci 1211bf215546Sopenharmony_ci self.xforms.append(xform) 1212bf215546Sopenharmony_ci if xform.search.opcode in conv_opcode_types: 1213bf215546Sopenharmony_ci dst_type = conv_opcode_types[xform.search.opcode] 1214bf215546Sopenharmony_ci for size in type_sizes(dst_type): 1215bf215546Sopenharmony_ci sized_opcode = xform.search.opcode + str(size) 1216bf215546Sopenharmony_ci self.opcode_xforms[sized_opcode].append(xform) 1217bf215546Sopenharmony_ci else: 1218bf215546Sopenharmony_ci self.opcode_xforms[xform.search.opcode].append(xform) 1219bf215546Sopenharmony_ci 1220bf215546Sopenharmony_ci # Check to make sure the search pattern does not unexpectedly contain 1221bf215546Sopenharmony_ci # more commutative expressions than match_expression (nir_search.c) 1222bf215546Sopenharmony_ci # can handle. 1223bf215546Sopenharmony_ci comm_exprs = xform.search.comm_exprs 1224bf215546Sopenharmony_ci 1225bf215546Sopenharmony_ci if xform.search.many_commutative_expressions: 1226bf215546Sopenharmony_ci if comm_exprs <= nir_search_max_comm_ops: 1227bf215546Sopenharmony_ci print("Transform expected to have too many commutative " \ 1228bf215546Sopenharmony_ci "expression but did not " \ 1229bf215546Sopenharmony_ci "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1230bf215546Sopenharmony_ci file=sys.stderr) 1231bf215546Sopenharmony_ci print(" " + str(xform), file=sys.stderr) 1232bf215546Sopenharmony_ci traceback.print_exc(file=sys.stderr) 1233bf215546Sopenharmony_ci print('', file=sys.stderr) 1234bf215546Sopenharmony_ci error = True 1235bf215546Sopenharmony_ci else: 1236bf215546Sopenharmony_ci if comm_exprs > nir_search_max_comm_ops: 1237bf215546Sopenharmony_ci print("Transformation with too many commutative expressions " \ 1238bf215546Sopenharmony_ci "({} > {}). Modify pattern or annotate with " \ 1239bf215546Sopenharmony_ci "\"many-comm-expr\".".format(comm_exprs, 1240bf215546Sopenharmony_ci nir_search_max_comm_ops), 1241bf215546Sopenharmony_ci file=sys.stderr) 1242bf215546Sopenharmony_ci print(" " + str(xform.search), file=sys.stderr) 1243bf215546Sopenharmony_ci print("{}".format(xform.search.cond), file=sys.stderr) 1244bf215546Sopenharmony_ci error = True 1245bf215546Sopenharmony_ci 1246bf215546Sopenharmony_ci self.automaton = TreeAutomaton(self.xforms) 1247bf215546Sopenharmony_ci 1248bf215546Sopenharmony_ci if error: 1249bf215546Sopenharmony_ci sys.exit(1) 1250bf215546Sopenharmony_ci 1251bf215546Sopenharmony_ci 1252bf215546Sopenharmony_ci def render(self): 1253bf215546Sopenharmony_ci return _algebraic_pass_template.render(pass_name=self.pass_name, 1254bf215546Sopenharmony_ci xforms=self.xforms, 1255bf215546Sopenharmony_ci opcode_xforms=self.opcode_xforms, 1256bf215546Sopenharmony_ci condition_list=condition_list, 1257bf215546Sopenharmony_ci automaton=self.automaton, 1258bf215546Sopenharmony_ci expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1259bf215546Sopenharmony_ci variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1260bf215546Sopenharmony_ci get_c_opcode=get_c_opcode, 1261bf215546Sopenharmony_ci itertools=itertools) 1262bf215546Sopenharmony_ci 1263bf215546Sopenharmony_ci# The replacement expression isn't necessarily exact if the search expression is exact. 1264bf215546Sopenharmony_cidef ignore_exact(*expr): 1265bf215546Sopenharmony_ci expr = SearchExpression.create(expr) 1266bf215546Sopenharmony_ci expr.ignore_exact = True 1267bf215546Sopenharmony_ci return expr 1268