Lines Matching refs:self
86 def __init__(self):
87 self.names = {}
88 self.ids = itertools.count()
89 self.immutable = False;
91 def __getitem__(self, name):
92 if name not in self.names:
93 assert not self.immutable, "Unknown replacement variable: " + name
94 self.names[name] = next(self.ids)
96 return self.names[name]
98 def lock(self):
99 self.immutable = True
102 def __init__(self, expr):
103 self.opcode = expr[0]
104 self.sources = expr[1:]
105 self.ignore_exact = False
115 def __repr__(self):
116 l = [self.opcode, *self.sources]
117 if self.ignore_exact:
136 def __init__(self, val, name, type_str):
137 self.in_val = str(val)
138 self.name = name
139 self.type_str = type_str
141 def __str__(self):
142 return self.in_val
144 def get_bit_size(self):
153 bit_size = self
160 if bit_size is not self:
161 self._bit_size = bit_size
164 def set_bit_size(self, other):
165 """Make self.get_bit_size() return what other.get_bit_size() return
170 self_bit_size = self.get_bit_size()
179 def type_enum(self):
180 return "nir_search_value_" + self.type_str
183 def c_bit_size(self):
184 bit_size = self.get_bit_size()
219 def render(self, cache):
220 struct_init = self.__template.render(val=self,
227 self.array_index = cache[struct_init]
228 return " /* {} -> {} in the cache */\n".format(self.name,
231 self.array_index = str(cache["next_index"])
232 cache[struct_init] = self.array_index
239 def __init__(self, val, name):
240 Value.__init__(self, val, name, "constant")
244 self.value = ast.literal_eval(m.group('value'))
245 self._bit_size = int(m.group('bits')) if m.group('bits') else None
247 self.value = val
248 self._bit_size = None
250 if isinstance(self.value, bool):
251 assert self._bit_size is None or self._bit_size == 1
252 self._bit_size = 1
254 def hex(self):
255 if isinstance(self.value, (bool)):
256 return 'NIR_TRUE' if self.value else 'NIR_FALSE'
257 if isinstance(self.value, int):
258 return hex(self.value)
259 elif isinstance(self.value, float):
260 return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
264 def type(self):
265 if isinstance(self.value, (bool)):
267 elif isinstance(self.value, int):
269 elif isinstance(self.value, float):
272 def equivalent(self, other):
280 if not isinstance(other, type(self)):
283 return self.value == other.value
294 def __init__(self, val, name, varset, algebraic_pass):
295 Value.__init__(self, val, name, "variable")
301 self.var_name = m.group('name')
306 assert self.var_name.isalpha()
307 assert self.var_name != 'True'
308 assert self.var_name != 'False'
310 self.is_constant = m.group('const') is not None
311 self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
312 self.required_type = m.group('type')
313 self._bit_size = int(m.group('bits')) if m.group('bits') else None
314 self.swiz = m.group('swiz')
316 if self.required_type == 'bool':
317 if self._bit_size is not None:
318 assert self._bit_size in type_sizes(self.required_type)
320 self._bit_size = 1
322 if self.required_type is not None:
323 assert self.required_type in ('float', 'bool', 'int', 'uint')
325 self.index = varset[self.var_name]
327 def type(self):
328 if self.required_type == 'bool':
330 elif self.required_type in ('int', 'uint'):
332 elif self.required_type == 'float':
335 def equivalent(self, other):
343 if not isinstance(other, type(self)):
346 return self.index == other.index
348 def swizzle(self):
349 if self.swiz is not None:
355 return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
362 def __init__(self, expr, name_base, varset, algebraic_pass):
363 Value.__init__(self, expr, name_base, "expression")
370 self.opcode = m.group('opcode')
371 self._bit_size = int(m.group('bits')) if m.group('bits') else None
372 self.inexact = m.group('inexact') is not None
373 self.exact = m.group('exact') is not None
374 self.ignore_exact = expr.ignore_exact
375 self.cond = m.group('cond')
377 assert not self.inexact or not self.exact, \
383 self.many_commutative_expressions = False
384 if self.cond and self.cond.find("many-comm-expr") >= 0:
387 c = self.cond[1:-1].split(",")
391 self.cond = c[0] if c else None
392 self.many_commutative_expressions = True
396 self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
398 self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
402 assert len(self.sources) <= 4
404 if self.opcode in conv_opcode_types:
405 assert self._bit_size is None, \
409 self.__index_comm_exprs(0)
411 def equivalent(self, other):
422 if not isinstance(other, type(self)):
425 if len(self.sources) != len(other.sources):
428 if self.opcode != other.opcode:
431 return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
433 def __index_comm_exprs(self, base_idx):
436 self.comm_exprs = 0
438 # A note about the explicit "len(self.sources)" check. The list of
442 if self.opcode not in conv_opcode_types and \
443 "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
444 len(self.sources) >= 2 and \
445 not self.sources[0].equivalent(self.sources[1]):
446 self.comm_expr_idx = base_idx
447 self.comm_exprs += 1
449 self.comm_expr_idx = -1
451 for s in self.sources:
453 s.__index_comm_exprs(base_idx + self.comm_exprs)
454 self.comm_exprs += s.comm_exprs
456 return self.comm_exprs
458 def c_opcode(self):
459 return get_c_opcode(self.opcode)
461 def render(self, cache):
462 srcs = "".join(src.render(cache) for src in self.sources)
463 return srcs + super(Expression, self).render(cache)
551 def __init__(self, varset):
552 self._var_classes = [None] * len(varset.names)
554 def compare_bitsizes(self, a, b):
576 return -1 if self.is_search else None
581 return 1 if self.is_search else None
583 return 0 if self.is_search or a.index == b.index else None
594 def unify_bit_size(self, a, b, error_msg):
597 the bit-sizes of self and other to get a message and raise an error.
604 cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
614 def merge_variables(self, val):
622 if self._var_classes[val.index] is None:
623 self._var_classes[val.index] = val
625 other = self._var_classes[val.index]
626 self.unify_bit_size(other, val,
633 self.merge_variables(src)
635 def validate_value(self, val):
655 self.validate_value(val.sources[0])
664 self.validate_value(src)
679 if self.is_search:
680 self.unify_bit_size(first_unsized_src, src,
687 self.unify_bit_size(first_unsized_src, src,
695 if self.is_search:
696 self.unify_bit_size(src, src_type_bits,
702 self.unify_bit_size(src, src_type_bits,
711 if self.is_search:
712 self.unify_bit_size(val, first_unsized_src,
718 self.unify_bit_size(val, first_unsized_src,
725 self.unify_bit_size(val, dst_type_bits,
731 def validate_replace(self, val, search):
741 self.validate_replace(src, search)
743 def validate(self, search, replace):
744 self.is_search = True
745 self.merge_variables(search)
746 self.merge_variables(replace)
747 self.validate_value(search)
749 self.is_search = False
750 self.validate_value(replace)
756 cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
765 self.validate_replace(replace, search)
772 def __init__(self, transform, algebraic_pass):
773 self.id = next(_optimization_ids)
778 self.condition = transform[2]
780 self.condition = 'true'
782 if self.condition not in condition_list:
783 condition_list.append(self.condition)
784 self.condition_index = condition_list.index(self.condition)
788 self.search = search
790 self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
795 self.replace = replace
797 self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
799 BitSizeValidator(varset).validate(self.search, self.replace)
817 def __init__(self, transforms):
818 self.patterns = [t.search for t in transforms]
819 self._compute_items()
820 self._build_table()
821 #print('num items: {}'.format(len(set(self.items.values()))))
822 #print('num states: {}'.format(len(self.states)))
823 #for state, patterns in zip(self.states, self.patterns):
832 def __init__(self, iterable=()):
833 self.objects = []
834 self.map = {}
836 self.add(obj)
838 def __getitem__(self, i):
839 return self.objects[i]
841 def __contains__(self, obj):
842 return obj in self.map
844 def __len__(self):
845 return len(self.objects)
847 def __iter__(self):
848 return iter(self.objects)
850 def clear(self):
851 self.objects = []
852 self.map.clear()
854 def index(self, obj):
855 return self.map[obj]
857 def add(self, obj):
858 if obj in self.map:
859 return self.map[obj]
861 index = len(self.objects)
862 self.objects.append(obj)
863 self.map[obj] = index
866 def __repr__(self):
867 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
876 def __init__(self, opcode, children):
877 self.opcode = opcode
878 self.children = children
880 self.patterns = []
883 self.parent_ops = set()
885 def __str__(self):
886 return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
888 def __repr__(self):
889 return str(self)
891 def _compute_items(self):
894 self.items = {}
898 self.opcodes = self.IndexMap()
903 item = self.items.setdefault((opcode, children),
904 self.Item(opcode, children))
906 self.items[opcode, (children[1], children[0]) + children[2:]] = item
911 self.wildcard = get_item("__wildcard", ())
912 self.const = get_item("__const", ())
917 return self.const
920 return self.const
924 return self.wildcard
939 self.opcodes.add(opcode)
946 for i, pattern in enumerate(self.patterns):
949 def _build_table(self):
958 self.table = defaultdict(dict)
960 # len(self.states)
961 self.states = self.IndexMap()
963 self.state_patterns = [None]
965 self.state_pattern_offsets = []
967 self.filter = defaultdict(list)
970 # q_{a,j} in the original algorithm is len(self.rep[op]).
971 self.rep = defaultdict(self.IndexMap)
973 # Everything in self.states with a index at least worklist_index is part
979 self.worklist_index = 0
986 new_opcodes = self.IndexMap()
993 while self.worklist_index < len(self.states):
994 state = self.states[self.worklist_index]
1004 self.state_pattern_offsets.append(len(self.state_patterns))
1005 self.state_patterns.extend(patterns)
1006 self.state_patterns.append(None)
1009 self.state_pattern_offsets.append(0)
1013 for op in self.opcodes:
1014 filt = self.filter[op]
1015 rep = self.rep[op]
1023 assert len(filt) == self.worklist_index
1025 self.worklist_index += 1
1033 self.states.add(frozenset((self.wildcard,)))
1034 self.states.add(frozenset((self.const,self.wildcard)))
1039 rep = self.rep[op]
1040 table = self.table[op]
1057 parent = set(self.items[op, item_srcs] for item_srcs in
1058 itertools.product(*srcs) if (op, item_srcs) in self.items)
1062 parent.add(self.wildcard)
1064 table[src_indices] = self.states.add(frozenset(parent))
1190 def __init__(self, pass_name, transforms):
1191 self.xforms = []
1192 self.opcode_xforms = defaultdict(lambda : [])
1193 self.pass_name = pass_name
1194 self.expression_cond = {}
1195 self.variable_cond = {}
1202 xform = SearchAndReplace(xform, self)
1211 self.xforms.append(xform)
1216 self.opcode_xforms[sized_opcode].append(xform)
1218 self.opcode_xforms[xform.search.opcode].append(xform)
1246 self.automaton = TreeAutomaton(self.xforms)
1252 def render(self):
1253 return _algebraic_pass_template.render(pass_name=self.pass_name,
1254 xforms=self.xforms,
1255 opcode_xforms=self.opcode_xforms,
1257 automaton=self.automaton,
1258 expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
1259 variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),