1# encoding=utf-8
2
3# Copyright © 2022 Imagination Technologies Ltd.
4
5# based on anv driver gen_pack_header.py which is:
6# Copyright © 2016 Intel Corporation
7
8# based on v3dv driver gen_pack_header.py which is:
9# Copyright (C) 2016 Broadcom
10
11# Permission is hereby granted, free of charge, to any person obtaining a copy
12# of this software and associated documentation files (the "Software"), to deal
13# in the Software without restriction, including without limitation the rights
14# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15# copies of the Software, and to permit persons to whom the Software is
16# furnished to do so, subject to the following conditions:
17
18# The above copyright notice and this permission notice (including the next
19# paragraph) shall be included in all copies or substantial portions of the
20# Software.
21
22# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28# SOFTWARE.
29
30from __future__ import annotations
31
32import copy
33import os
34import textwrap
35import typing as t
36import xml.parsers.expat as expat
37from abc import ABC
38from ast import literal_eval
39
40
# C block comment holding the MIT license text for generated headers.
# %-formatted with a single key, "copyright" (Csbgen.emit() supplies it), and
# substituted into PACK_FILE_HEADER as its "license" entry.
MIT_LICENSE_COMMENT = """/*
 * Copyright © %(copyright)s
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */"""

# Preamble printed at the top of every generated pack header.
# %-formatted with:
#   "license"  - the rendered MIT_LICENSE_COMMENT,
#   "platform" - name shown in the descriptive comment,
#   "guard"    - include-guard macro name (opened here; the matching
#                "#endif" is printed by Csbgen.emit() after all content).
PACK_FILE_HEADER = """%(license)s

/* Enums, structures and pack functions for %(platform)s.
 *
 * This file has been generated, do not hand edit.
 */

#ifndef %(guard)s
#define %(guard)s

#include "csbgen/pvr_packet_helpers.h"

"""
77
78
def safe_name(name: str) -> str:
    """Prefix *name* with an underscore unless it starts with a letter.

    For example ``"1D"`` becomes ``"_1D"`` and ``"_x"`` becomes ``"__x"``,
    while ``"foo"`` is returned unchanged.

    Raises:
        ValueError: if *name* is empty (there is no first character to test,
            and an empty name cannot form a usable identifier).
    """
    if not name:
        raise ValueError("Name cannot be empty")

    if not name[0].isalpha():
        name = "_" + name

    return name
84
85
def num_from_str(num_str: str) -> int:
    """Parse an integer literal written in decimal or ``0x``/``0X`` hex.

    Decimal literals with a leading zero (other than ``"0"`` itself) are
    rejected so that C-style octal is never silently misread.

    Raises:
        ValueError: for octal-looking input or anything ``int()`` rejects.
    """
    is_hex = num_str.lower().startswith("0x")
    if is_hex:
        return int(num_str, base=16)

    looks_octal = len(num_str) > 1 and num_str[0] == "0"
    if looks_octal:
        raise ValueError("Octal numbers not allowed")

    return int(num_str)
94
95
class Node(ABC):
    """Common base for every element in the csbgen description tree.

    Each node keeps a reference to its parent and a name. Generated C names are
    built by joining the parent's ``prefix`` with the upper-cased node name.
    """

    __slots__ = ["parent", "name"]

    # The root (Csbgen) passes None here; every other node has a real parent.
    parent: t.Optional[Node]
    name: str

    def __init__(self, parent: t.Optional[Node], name: str, *, name_is_safe: bool = False) -> None:
        self.parent = parent
        self.name = name if name_is_safe else safe_name(name)

    @property
    def full_name(self) -> str:
        # A name that already starts with "_" supplies its own separator.
        separator = "" if self.name[0] == "_" else "_"
        return self.parent.prefix + separator + self.name.upper()

    @property
    def prefix(self) -> str:
        # By default children share this node's parent prefix.
        return self.parent.prefix

    def add(self, element: Node) -> None:
        """Reject nesting; subclasses override this to accept their children."""
        container = type(self).__name__.lower()
        nested = type(element).__name__
        raise RuntimeError("Element cannot be nested in %s. Element Type: %s"
                           % (container, nested))
123
124
class Csbgen(Node):
    """Root node: owns every define, enum and struct and emits the header.

    Names of generated C symbols start with ``<NAME>_<PREFIX>`` (see
    ``full_name``), which this node exposes to its children as ``prefix``.
    """

    __slots__ = ["prefix_field", "filename", "_defines", "_enums", "_structs"]

    prefix_field: str
    filename: str
    _defines: t.List[Define]
    _enums: t.Dict[str, Enum]
    _structs: t.Dict[str, Struct]

    def __init__(self, name: str, prefix: str, filename: str) -> None:
        # The root has no parent.
        super().__init__(None, name.upper())
        self.prefix_field = safe_name(prefix.upper())
        self.filename = filename

        # Defines keep declaration order in a list; enums and structs are
        # looked up by name, so they live in (insertion-ordered) dicts.
        self._defines, self._enums, self._structs = [], {}, {}

    @property
    def full_name(self) -> str:
        return "%s_%s" % (self.name, self.prefix_field)

    @property
    def prefix(self) -> str:
        return self.full_name

    def add(self, element: Node) -> None:
        """Register a top-level Enum, Struct or Define, rejecting duplicates."""
        if isinstance(element, Enum):
            enum_name = element.name
            if enum_name in self._enums:
                raise RuntimeError("Enum redefined. Enum: %s" % enum_name)
            self._enums[enum_name] = element
            return

        if isinstance(element, Struct):
            struct_name = element.name
            if struct_name in self._structs:
                raise RuntimeError("Struct redefined. Struct: %s" % struct_name)
            self._structs[struct_name] = element
            return

        if isinstance(element, Define):
            wanted = element.full_name
            if any(existing.full_name == wanted for existing in self._defines):
                raise RuntimeError("Define redefined. Define: %s" % wanted)
            self._defines.append(element)
            return

        super().add(element)

    def _gen_guard(self) -> str:
        # "dir/foo.xml" -> "FOO_H"
        base = os.path.basename(self.filename)
        return base.replace(".xml", "_h").upper()

    def emit(self) -> None:
        """Print the complete generated header to stdout."""
        guard = self._gen_guard()
        license_text = MIT_LICENSE_COMMENT % {"copyright": "2022 Imagination Technologies Ltd."}
        print(PACK_FILE_HEADER % {
            "license": license_text,
            "platform": self.name,
            "guard": guard,
        })

        for define in self._defines:
            define.emit()
        print()

        for enum in self._enums.values():
            enum.emit()

        for struct in self._structs.values():
            struct.emit(self)

        print("#endif /* %s */" % guard)

    def is_known_struct(self, struct_name: str) -> bool:
        return struct_name in self._structs

    def is_known_enum(self, enum_name: str) -> bool:
        return enum_name in self._enums

    def get_enum(self, enum_name: str) -> Enum:
        return self._enums[enum_name]

    def get_struct(self, struct_name: str) -> Struct:
        return self._structs[struct_name]
205
206
class Enum(Node):
    """A named set of Value children, emitted as a C ``enum``."""

    __slots__ = ["_values"]

    _values: t.Dict[str, Value]

    def __init__(self, parent: Node, name: str) -> None:
        super().__init__(parent, name)
        self._values = {}
        self.parent.add(self)

    # Overridden so each value's generated name also contains the enum's name.
    @property
    def prefix(self) -> str:
        return self.full_name

    def get_value(self, value_name: str) -> Value:
        return self._values[value_name]

    def add(self, element: Node) -> None:
        """Accept Value children only; duplicates are rejected."""
        if not isinstance(element, Value):
            # Node.add always raises for unsupported children.
            super().add(element)
            return

        value_name = element.name
        if value_name in self._values:
            raise RuntimeError("Value is being redefined. Value: '%s'" % value_name)
        self._values[value_name] = element

    def emit(self) -> None:
        # Only Value children can be nested, so "no values" means "empty enum".
        if len(self._values) == 0:
            raise RuntimeError("Enum definition is empty. Enum: '%s'" % self.full_name)

        print("enum %s {" % self.full_name)
        for entry in self._values.values():
            entry.emit()
        print("};\n")
245
246
class Value(Node):
    """A single enumerator of an Enum, carrying its integer value."""

    __slots__ = ["value"]

    value: int

    def __init__(self, parent: Node, name: str, value: int) -> None:
        super().__init__(parent, name)
        self.value = value
        # Register with the enum once the value is fully initialised.
        self.parent.add(self)

    def emit(self) -> None:
        line = "    %-36s = %6d," % (self.full_name, self.value)
        print(line)
261
262
class Struct(Node):
    """A packed structure that is ``length`` 32-bit dwords long.

    Its children are fields and top-level 'if' conditions. emit() prints:
    - the ``_length`` define,
    - the ``_header`` default-initializer macro,
    - per-field helper macros,
    - the C struct,
    - the struct's pack and unpack functions.
    """

    __slots__ = ["length", "size", "_children"]

    length: int
    size: int
    _children: t.Dict[str, t.Union[Condition, Field]]

    def __init__(self, parent: Node, name: str, length: int) -> None:
        super().__init__(parent, name)

        self.length = length
        # Total size in bits (one 32-bit dword per unit of length).
        self.size = self.length * 32

        if self.length <= 0:
            raise ValueError("Struct length must be greater than 0. Struct: '%s'." % self.full_name)

        self._children = {}

        self.parent.add(self)

    @property
    def fields(self) -> t.List[Field]:
        """All fields in insertion order, flattened out of every condition branch."""
        # TODO: Should we cache? See TODO in equivalent Condition getter.

        fields = []
        for child in self._children.values():
            if isinstance(child, Condition):
                fields += child.fields
            else:
                fields.append(child)

        return fields

    @property
    def prefix(self) -> str:
        return self.full_name

    def add(self, element: Node) -> None:
        """Register a Field or an 'if' Condition as a direct child.

        Fields and 'if' conditions share one name space. Reusing a name would
        otherwise silently replace the earlier child and drop all of its
        fields, so a reused name is rejected.

        Raises:
            ValueError: a field reuses an existing child name.
            RuntimeError: an 'if' reuses an existing child name, or a
                non-'if' branch refers to an unknown condition.
        """
        if isinstance(element, Field):
            if element.name in self._children:
                raise ValueError("Field is being redefined. Field: '%s', Struct: '%s'"
                                 % (element.name, self.full_name))

            self._children[element.name] = element

        elif isinstance(element, Condition):
            # We only save ifs, and ignore the rest. The rest will be linked to
            # the if condition so we just need to call emit() on the if and the
            # rest will also be emitted.
            if element.type == "if":
                if element.name in self._children:
                    raise RuntimeError("Condition name is already in use. Check: '%s', Struct: '%s'"
                                       % (element.name, self.full_name))

                self._children[element.name] = element
            elif element.name not in self._children:
                raise RuntimeError("Unknown condition: '%s'" % element.name)

        else:
            super().add(element)

    def _emit_header(self, root: Csbgen) -> None:
        """Print the ``<NAME>_header`` macro initializing fields that have defaults."""
        default_fields = []
        for field in (f for f in self.fields if f.default is not None):
            if field.is_builtin_type:
                default_fields.append("    .%-35s = %6d" % (field.name, field.default))
            else:
                if not root.is_known_enum(field.type):
                    # Default values should not apply to structures
                    raise RuntimeError(
                        "Unknown type. Field: '%s' Type: '%s'"
                        % (field.name, field.type)
                    )

                enum = root.get_enum(field.type)

                try:
                    value = enum.get_value(field.default)
                except KeyError:
                    # The KeyError adds nothing beyond this message.
                    raise ValueError("Unknown enum value. Value: '%s', Enum: '%s', Field: '%s'"
                                     % (field.default, enum.full_name, field.name)) from None

                default_fields.append("    .%-35s = %s" % (field.name, value.full_name))

        print("#define %-40s\\" % (self.full_name + "_header"))
        print(",  \\\n".join(default_fields))
        print("")

    def _emit_helper_macros(self) -> None:
        """Print each field's helper defines (e.g. address ALIGNMENT)."""
        for field in (f for f in self.fields if f.defines):
            print("/* Helper macros for %s */" % field.name)

            for define in field.defines:
                define.emit()

            print()

    def _emit_pack_function(self, root: Csbgen) -> None:
        """Print ``<NAME>_pack()``, which writes ``values`` into the ``dst`` dwords."""
        print(textwrap.dedent("""\
            static inline __attribute__((always_inline)) void
            %s_pack(__attribute__((unused)) void * restrict dst,
                  %s__attribute__((unused)) const struct %s * restrict values)
            {""") % (self.full_name, ' ' * len(self.full_name), self.full_name))

        group = Group(0, 1, self.size, self.fields)
        dwords, length = group.collect_dwords_and_length()
        if length:
            # Cast dst to make header C++ friendly
            print("    uint32_t * restrict dw = (uint32_t * restrict) dst;")

        group.emit_pack_function(root, dwords, length)

        print("}\n")

    def _emit_unpack_function(self, root: Csbgen) -> None:
        """Print ``<NAME>_unpack()``, which reads the ``src`` dwords into ``values``."""
        print(textwrap.dedent("""\
            static inline __attribute__((always_inline)) void
            %s_unpack(__attribute__((unused)) const void * restrict src,
                    %s__attribute__((unused)) struct %s * restrict values)
            {""") % (self.full_name, ' ' * len(self.full_name), self.full_name))

        group = Group(0, 1, self.size, self.fields)
        dwords, length = group.collect_dwords_and_length()
        if length:
            # Cast src to make header C++ friendly
            print("    const uint32_t * restrict dw = (const uint32_t * restrict) src;")

        group.emit_unpack_function(root, dwords, length)

        print("}\n")

    def emit(self, root: Csbgen) -> None:
        """Print every C artifact for this struct."""
        print("#define %-33s %6d" % (self.full_name + "_length", self.length))

        self._emit_header(root)

        self._emit_helper_macros()

        print("struct %s {" % self.full_name)
        for child in self._children.values():
            child.emit(root)
        print("};\n")

        self._emit_pack_function(root)
        self._emit_unpack_function(root)
406
407
class Field(Node):
    """A struct member occupying the inclusive bit range ``[start, end]``.

    ``type`` is either one of ``_BUILTIN_TYPES`` or the name of an enum or
    struct known to the root node. Address fields must have a ``shift``, which
    also creates an ``ALIGNMENT`` helper define of ``2**shift``; no other type
    may have a shift.
    """

    __slots__ = ["start", "end", "type", "default", "shift", "_defines"]

    # Types understood directly by the generator; any other type must name a
    # struct or enum known to the root Csbgen node.
    _BUILTIN_TYPES = frozenset({"address", "bool", "float", "mbo", "offset", "int", "uint"})

    start: int
    end: int
    type: str
    default: t.Optional[t.Union[str, int]]
    shift: t.Optional[int]
    _defines: t.Dict[str, Define]

    def __init__(self, parent: Node, name: str, start: int, end: int, ty: str, *,
                 default: t.Optional[str] = None, shift: t.Optional[int] = None) -> None:
        super().__init__(parent, name)

        self.start = start
        self.end = end
        self.type = ty

        self._defines = {}

        self.parent.add(self)

        if self.start > self.end:
            raise ValueError("Start cannot be after end. Start: %d, End: %d, Field: '%s'"
                             % (self.start, self.end, self.name))

        # Bit ranges are inclusive, so a 1-bit field has start == end.
        if self.type == "bool" and self.end != self.start:
            raise ValueError("Bool field can only be 1 bit long. Field '%s'" % self.name)

        if default is not None:
            if not self.is_builtin_type:
                # Assuming it's an enum type; the default names one of its values.
                self.default = safe_name(default)
            else:
                self.default = num_from_str(default)
        else:
            self.default = None

        if shift is not None:
            if self.type != "address":
                raise RuntimeError("Only address fields can have a shift attribute. Field: '%s'" % self.name)

            self.shift = int(shift)

            Define(self, "ALIGNMENT", 2**self.shift)
        else:
            if self.type == "address":
                raise RuntimeError("Field of address type requires a shift attribute. Field '%s'" % self.name)

            self.shift = None

    @property
    def defines(self) -> t.ValuesView[Define]:
        """Helper defines attached to this field (a sized, re-iterable view)."""
        return self._defines.values()

    # We override prefix so that the defines will contain the field's name too.
    @property
    def prefix(self) -> str:
        return self.full_name

    @property
    def is_builtin_type(self) -> bool:
        return self.type in self._BUILTIN_TYPES

    def _get_c_type(self, root: Csbgen) -> str:
        """Return the C type used for this field's member in the generated struct.

        Raises:
            RuntimeError: the type is unknown, or a uint is wider than 64 bits.
        """
        if self.type == "address":
            return "__pvr_address_type"
        elif self.type == "bool":
            return "bool"
        elif self.type == "float":
            return "float"
        elif self.type == "offset":
            return "uint64_t"
        elif self.type == "int":
            # NOTE(review): int fields always map to int32_t regardless of their
            # bit width; confirm no signed field is wider than 32 bits.
            return "int32_t"
        elif self.type == "uint":
            # end - start is (width - 1) because the bit range is inclusive.
            if self.end - self.start < 32:
                return "uint32_t"
            elif self.end - self.start < 64:
                return "uint64_t"

            raise RuntimeError("No known C type found to hold %d bit sized value. Field: '%s'"
                               % (self.end - self.start + 1, self.name))
        elif root.is_known_struct(self.type):
            return "struct " + self.type
        elif root.is_known_enum(self.type):
            return "enum " + root.get_enum(self.type).full_name
        raise RuntimeError("Unknown type. Type: '%s', Field: '%s'" % (self.type, self.name))

    def add(self, element: Node) -> None:
        """Attach a helper Define; mbo fields accept no children at all."""
        if self.type == "mbo":
            raise RuntimeError("No element can be nested in an mbo field. Element Type: %s, Field: %s"
                               % (type(element).__name__, self.name))

        if isinstance(element, Define):
            if element.name in self._defines:
                raise RuntimeError("Duplicate define. Define: '%s'" % element.name)

            self._defines[element.name] = element
        else:
            super().add(element)

    def emit(self, root: Csbgen) -> None:
        """Print this field's struct member; mbo fields have no member."""
        if self.type == "mbo":
            return

        print("    %-36s %s;" % (self._get_c_type(root), self.name))
517
class Define(Node):
    """A named integer constant, emitted as a C ``#define``."""

    __slots__ = ["value"]

    value: int

    def __init__(self, parent: Node, name: str, value: int) -> None:
        super().__init__(parent, name)
        self.value = value
        # Register last so the parent sees a fully initialised define.
        self.parent.add(self)

    def emit(self) -> None:
        text = "#define %-40s %d" % (self.full_name, self.value)
        print(text)
532
533
class Condition(Node):
    """One branch of a conditional block of fields: 'if', 'elif', 'else' or 'endif'.

    An 'if' is stored among its parent's children. Each following branch is
    linked through ``_child_branch``, so emitting the 'if' walks the chain down
    to the terminating 'endif'. The check name is used verbatim (it is not
    passed through safe_name()).
    """

    __slots__ = ["type", "_children", "_child_branch"]

    type: str
    _children: t.Dict[str, t.Union[Condition, Field]]
    _child_branch: t.Optional[Condition]

    def __init__(self, parent: Node, name: str, ty: str) -> None:
        super().__init__(parent, name, name_is_safe=True)

        self.type = ty
        if not Condition._is_valid_type(self.type):
            raise RuntimeError("Unknown condition type. Check: '%s', Type: '%s'"
                               % (self.name, self.type))

        self._children = {}

        # This is the link to the next branch for the if statement so either
        # elif, else, or endif. They themselves will also have a link to the
        # next branch up until endif which terminates the chain.
        self._child_branch = None

        self.parent.add(self)

    @property
    def fields(self) -> t.List[Field]:
        """Fields of this branch, its nested conditions and all later branches."""
        # TODO: Should we use some kind of state to indicate the all of the
        # child nodes have been added and then cache the fields in here on the
        # first call so that we don't have to traverse them again per each call?
        # The state could be changed wither when we reach the endif and pop from
        # the context, or when we start emitting.

        fields = []

        for child in self._children.values():
            if isinstance(child, Condition):
                fields += child.fields
            else:
                fields.append(child)

        if self._child_branch is not None:
            fields += self._child_branch.fields

        return fields

    @staticmethod
    def _is_valid_type(ty: str) -> bool:
        types = {"if", "elif", "else", "endif"}
        return ty in types

    def _is_compatible_child_branch(self, branch: Condition) -> bool:
        # A branch may only be followed by a later branch kind; elif may also
        # follow another elif.
        types = ["if", "elif", "else", "endif"]
        idx = types.index(self.type)
        return (branch.type in types[idx + 1:] or
                self.type == "elif" and branch.type == "elif")

    def _add_branch(self, branch: Condition) -> None:
        if branch.type == "elif" and branch.name == self.name:
            raise RuntimeError("Elif branch cannot have same check as previous branch. Check: '%s'" % branch.name)

        if not self._is_compatible_child_branch(branch):
            raise RuntimeError("Invalid branch. Check: '%s', Type: '%s'" % (branch.name, branch.type))

        self._child_branch = branch

    # Returns the name of the if condition. This is used for elif branches since
    # they have a different name than the if condition thus we have to traverse
    # the chain of branches.
    # This is used to discriminate nested if conditions from branches since
    # branches like 'endif' and 'else' will have the same name as the 'if' (the
    # elif is an exception) while nested conditions will have different names.
    #
    # TODO: Redo this to improve speed? Would caching this be helpful? We could
    # just save the name of the if instead of having to walk towards it whenever
    # a new condition is being added.
    def _top_branch_name(self) -> str:
        if self.type == "if":
            return self.name

        # If we're not an 'if' condition, our parent must be another condition.
        assert isinstance(self.parent, Condition)
        return self.parent._top_branch_name()

    def add(self, element: Node) -> None:
        """Add a field, a nested 'if', or the next branch of this chain.

        Fields and nested 'if' conditions share one name space per branch.
        Reusing a name would silently replace the earlier child and lose its
        fields, so a reused name is rejected.
        """
        if isinstance(element, Field):
            if element.name in self._children:
                raise ValueError("Duplicate field. Field: '%s'" % element.name)

            self._children[element.name] = element
        elif isinstance(element, Condition):
            if element.type == "elif" or self._top_branch_name() == element.name:
                self._add_branch(element)
            else:
                if element.type != "if":
                    raise RuntimeError("Branch of an unopened if condition. Check: '%s', Type: '%s'."
                                       % (element.name, element.type))

                # This is a nested condition and we made sure that the name
                # doesn't match _top_branch_name() so we can recognize the else
                # and endif.
                # We recognized the elif by its type however its name differs
                # from the if condition thus when we add an if condition with
                # the same name as the elif nested in it, the _top_branch_name()
                # check doesn't hold true as the name matched the elif and not
                # the if statement which the elif was a branch of, thus the
                # nested if condition is not recognized as an invalid branch of
                # the outer if statement.
                #   Sample:
                #   <condition type="if" check="ROGUEXE"/>
                #       <condition type="elif" check="COMPUTE"/>
                #           <condition type="if" check="COMPUTE"/>
                #           <condition type="endif" check="COMPUTE"/>
                #       <condition type="endif" check="COMPUTE"/>
                #   <condition type="endif" check="ROGUEXE"/>
                #
                # We fix this by checking the if condition name against its
                # parent.
                if element.name == self.name:
                    raise RuntimeError("Invalid if condition. Check: '%s'" % element.name)

                if element.name in self._children:
                    raise RuntimeError("Condition name is already in use. Check: '%s'" % element.name)

                self._children[element.name] = element
        else:
            super().add(element)

    def emit(self, root: Csbgen) -> None:
        """Print this branch's marker comment, its children, then the next branch.

        Raises:
            RuntimeError: an if/elif/else branch was never followed by another
                branch (i.e. the chain is not terminated by an 'endif').
        """
        if self.type == "if":
            print("/* if %s is supported use: */" % self.name)
        elif self.type == "elif":
            print("/* else if %s is supported use: */" % self.name)
        elif self.type == "else":
            print("/* else %s is not-supported use: */" % self.name)
        elif self.type == "endif":
            print("/* endif %s */" % self.name)
            return
        else:
            raise RuntimeError("Unknown condition type. Implementation error.")

        for child in self._children.values():
            child.emit(root)

        if self._child_branch is None:
            raise RuntimeError("Condition is not terminated by an endif. Check: '%s', Type: '%s'"
                               % (self.name, self.type))

        self._child_branch.emit(root)
674
675
676class Group:
677    __slots__ = ["start", "count", "size", "fields"]
678
679    start: int
680    count: int
681    size: int
682    fields: t.List[Field]
683
684    def __init__(self, start: int, count: int, size: int, fields) -> None:
685        self.start = start
686        self.count = count
687        self.size = size
688        self.fields = fields
689
690    class DWord:
691        __slots__ = ["size", "fields", "addresses"]
692
693        size: int
694        fields: t.List[Field]
695        addresses: t.List[Field]
696
697        def __init__(self) -> None:
698            self.size = 32
699            self.fields = []
700            self.addresses = []
701
702    def collect_dwords(self, dwords: t.Dict[int, Group.DWord], start: int) -> None:
703        for field in self.fields:
704            index = (start + field.start) // 32
705            if index not in dwords:
706                dwords[index] = self.DWord()
707
708            clone = copy.copy(field)
709            clone.start = clone.start + start
710            clone.end = clone.end + start
711            dwords[index].fields.append(clone)
712
713            if field.type == "address":
714                # assert dwords[index].address == None
715                dwords[index].addresses.append(clone)
716
717            # Coalesce all the dwords covered by this field. The two cases we
718            # handle are where multiple fields are in a 64 bit word (typically
719            # and address and a few bits) or where a single struct field
720            # completely covers multiple dwords.
721            while index < (start + field.end) // 32:
722                if index + 1 in dwords and not dwords[index] == dwords[index + 1]:
723                    dwords[index].fields.extend(dwords[index + 1].fields)
724                    dwords[index].addresses.extend(dwords[index + 1].addresses)
725                dwords[index].size = 64
726                dwords[index + 1] = dwords[index]
727                index = index + 1
728
729    def collect_dwords_and_length(self) -> t.Tuple[t.Dict[int, Group.DWord], int]:
730        dwords = {}
731        self.collect_dwords(dwords, 0)
732
733        # Determine number of dwords in this group. If we have a size, use
734        # that, since that'll account for MBZ dwords at the end of a group
735        # (like dword 8 on BDW+ 3DSTATE_HS). Otherwise, use the largest dword
736        # index we've seen plus one.
737        if self.size > 0:
738            length = self.size // 32
739        elif dwords:
740            length = max(dwords.keys()) + 1
741        else:
742            length = 0
743
744        return dwords, length
745
746    def emit_pack_function(self, root: Csbgen, dwords: t.Dict[int, Group.DWord], length: int) -> None:
747        for index in range(length):
748            # Handle MBZ dwords
749            if index not in dwords:
750                print("")
751                print("    dw[%d] = 0;" % index)
752                continue
753
754            # For 64 bit dwords, we aliased the two dword entries in the dword
755            # dict it occupies. Now that we're emitting the pack function,
756            # skip the duplicate entries.
757            dw = dwords[index]
758            if index > 0 and index - 1 in dwords and dw == dwords[index - 1]:
759                continue
760
761            # Special case: only one field and it's a struct at the beginning
762            # of the dword. In this case we pack directly into the
763            # destination. This is the only way we handle embedded structs
764            # larger than 32 bits.
765            if len(dw.fields) == 1:
766                field = dw.fields[0]
767                if root.is_known_struct(field.type) and field.start % 32 == 0:
768                    print("")
769                    print("    %s_pack(data, &dw[%d], &values->%s);"
770                          % (self.parser.gen_prefix(safe_name(field.type)), index, field.name))
771                    continue
772
773            # Pack any fields of struct type first so we have integer values
774            # to the dword for those fields.
775            field_index = 0
776            for field in dw.fields:
777                if root.is_known_struct(field.type):
778                    print("")
779                    print("    uint32_t v%d_%d;" % (index, field_index))
780                    print("    %s_pack(data, &v%d_%d, &values->%s);"
781                          % (self.parser.gen_prefix(safe_name(field.type)), index, field_index, field.name))
782                    field_index = field_index + 1
783
784            print("")
785            dword_start = index * 32
786            address_count = len(dw.addresses)
787
788            if dw.size == 32 and not dw.addresses:
789                v = None
790                print("    dw[%d] =" % index)
791            elif len(dw.fields) > address_count:
792                v = "v%d" % index
793                print("    const uint%d_t %s =" % (dw.size, v))
794            else:
795                v = "0"
796
797            field_index = 0
798            non_address_fields = []
799            for field in dw.fields:
800                if field.type == "mbo":
801                    non_address_fields.append("__pvr_mbo(%d, %d)"
802                                              % (field.start - dword_start, field.end - dword_start))
803                elif field.type == "address":
804                    pass
805                elif field.type == "uint":
806                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
807                                              % (field.name, field.start - dword_start, field.end - dword_start))
808                elif root.is_known_enum(field.type):
809                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
810                                              % (field.name, field.start - dword_start, field.end - dword_start))
811                elif field.type == "int":
812                    non_address_fields.append("__pvr_sint(values->%s, %d, %d)"
813                                              % (field.name, field.start - dword_start, field.end - dword_start))
814                elif field.type == "bool":
815                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
816                                              % (field.name, field.start - dword_start, field.end - dword_start))
817                elif field.type == "float":
818                    non_address_fields.append("__pvr_float(values->%s)" % field.name)
819                elif field.type == "offset":
820                    non_address_fields.append("__pvr_offset(values->%s, %d, %d)"
821                                              % (field.name, field.start - dword_start, field.end - dword_start))
822                elif field.is_struct_type():
823                    non_address_fields.append("__pvr_uint(v%d_%d, %d, %d)"
824                                              % (index, field_index, field.start - dword_start,
825                                                 field.end - dword_start))
826                    field_index = field_index + 1
827                else:
828                    non_address_fields.append(
829                        "/* unhandled field %s," " type %s */\n" % (field.name, field.type)
830                    )
831
832            if non_address_fields:
833                print(" |\n".join("      " + f for f in non_address_fields) + ";")
834
835            if dw.size == 32:
836                for addr in dw.addresses:
837                    print("    dw[%d] = __pvr_address(values->%s, %d, %d, %d) | %s;"
838                          % (index, addr.name, addr.shift, addr.start - dword_start,
839                             addr.end - dword_start, v))
840                continue
841
842            v_accumulated_addr = ""
843            for i, addr in enumerate(dw.addresses):
844                v_address = "v%d_address" % i
845                v_accumulated_addr += "v%d_address" % i
846                print("    const uint64_t %s =" % v_address)
847                print("      __pvr_address(values->%s, %d, %d, %d);"
848                      % (addr.name, addr.shift, addr.start - dword_start, addr.end - dword_start))
849                if i < (address_count - 1):
850                    v_accumulated_addr += " |\n            "
851
852            if dw.addresses:
853                if len(dw.fields) > address_count:
854                    print("    dw[%d] = %s | %s;" % (index, v_accumulated_addr, v))
855                    print("    dw[%d] = (%s >> 32) | (%s >> 32);" % (index + 1, v_accumulated_addr, v))
856                    continue
857                else:
858                    v = v_accumulated_addr
859
860            print("    dw[%d] = %s;" % (index, v))
861            print("    dw[%d] = %s >> 32;" % (index + 1, v))
862
863    def emit_unpack_function(self, root: Csbgen, dwords: t.Dict[int, Group.DWord], length: int) -> None:
864        for index in range(length):
865            # Ignore MBZ dwords
866            if index not in dwords:
867                continue
868
869            # For 64 bit dwords, we aliased the two dword entries in the dword
870            # dict it occupies. Now that we're emitting the unpack function,
871            # skip the duplicate entries.
872            dw = dwords[index]
873            if index > 0 and index - 1 in dwords and dw == dwords[index - 1]:
874                continue
875
876            # Special case: only one field and it's a struct at the beginning
877            # of the dword. In this case we unpack directly from the
878            # source. This is the only way we handle embedded structs
879            # larger than 32 bits.
880            if len(dw.fields) == 1:
881                field = dw.fields[0]
882                if root.is_known_struct(field.type) and field.start % 32 == 0:
883                    prefix = root.get_struct(field.type)
884                    print("")
885                    print("    %s_unpack(data, &dw[%d], &values->%s);" % (prefix, index, field.name))
886                    continue
887
888            dword_start = index * 32
889
890            if dw.size == 32:
891                v = "dw[%d]" % index
892            elif dw.size == 64:
893                v = "v%d" % index
894                print("    const uint%d_t %s = dw[%d] | ((uint64_t)dw[%d] << 32);" % (dw.size, v, index, index + 1))
895            else:
896                raise RuntimeError("Unsupported dword size %d" % dw.size)
897
898            # Unpack any fields of struct type first.
899            for field_index, field in enumerate(f for f in dw.fields if root.is_known_struct(f.type)):
900                prefix = root.get_struct(field.type).prefix
901                vname = "v%d_%d" % (index, field_index)
902                print("")
903                print("    uint32_t %s = __pvr_uint_unpack(%s, %d, %d);"
904                      % (vname, v, field.start - dword_start, field.end - dword_start))
905                print("    %s_unpack(data, &%s, &values->%s);" % (prefix, vname, field.name))
906
907            for field in dw.fields:
908                dword_field_start = field.start - dword_start
909                dword_field_end = field.end - dword_start
910
911                if field.type == "mbo" or root.is_known_struct(field.type):
912                    continue
913                elif field.type == "uint" or root.is_known_enum(field.type) or field.type == "bool":
914                    print("    values->%s = __pvr_uint_unpack(%s, %d, %d);"
915                          % (field.name, v, dword_field_start, dword_field_end))
916                elif field.type == "int":
917                    print("    values->%s = __pvr_sint_unpack(%s, %d, %d);"
918                          % (field.name, v, dword_field_start, dword_field_end))
919                elif field.type == "float":
920                    print("    values->%s = __pvr_float_unpack(%s);" % (field.name, v))
921                elif field.type == "offset":
922                    print("    values->%s = __pvr_offset_unpack(%s, %d, %d);"
923                          % (field.name, v, dword_field_start, dword_field_end))
924                elif field.type == "address":
925                    print("    values->%s = __pvr_address_unpack(%s, %d, %d, %d);"
926                          % (field.name, v, field.shift, dword_field_start, dword_field_end))
927                else:
928                    print("/* unhandled field %s, type %s */" % (field.name, field.type))
929
930
class Parser:
    """Build csbgen nodes from an XML description using expat.

    Each opening tag creates a node whose parent is the node currently on
    top of ``context``. Closing tags pop the stack and check that the tag
    matches the node. When the ``csbgen`` root tag is closed, the root node's
    ``emit()`` is called.
    """

    __slots__ = ["parser", "context", "filename"]

    parser: expat.XMLParserType
    # Stack of currently open nodes; the last entry is the current parent.
    context: t.List[Node]
    # Path of the file being parsed; it is passed to the Csbgen root node.
    filename: str

    def __init__(self) -> None:
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.start_element
        self.parser.EndElementHandler = self.end_element

        self.context = []
        self.filename = ""

    def start_element(self, name: str, attrs: t.Dict[str, str]) -> None:
        """Handle an opening tag by creating its node and pushing it if needed.

        :param name: The tag name.
        :param attrs: The tag's attributes.
        :raises RuntimeError: for a second or nested ``csbgen`` tag, an element
            outside ``csbgen``, an unbalanced condition, or an unknown tag.
        """
        if name == "csbgen":
            if self.context:
                raise RuntimeError(
                    "Can only have 1 csbgen block and it has "
                    + "to contain all of the other elements."
                )

            csbgen = Csbgen(attrs["name"], attrs["prefix"], self.filename)
            self.context.append(csbgen)
            return

        # Every other element needs a parent; without one, self.context[-1]
        # would fail with an uninformative IndexError.
        if not self.context:
            raise RuntimeError("Element '%s' must be inside the csbgen block." % name)

        parent = self.context[-1]

        if name == "struct":
            struct = Struct(parent, attrs["name"], int(attrs["length"]))
            self.context.append(struct)

        elif name == "field":
            field = Field(parent, name=attrs["name"], start=int(attrs["start"]), end=int(attrs["end"]),
                          ty=attrs["type"], default=attrs.get("default"), shift=attrs.get("shift"))
            self.context.append(field)

        elif name == "enum":
            enum = Enum(parent, attrs["name"])
            self.context.append(enum)

        elif name == "value":
            value = Value(parent, attrs["name"], int(literal_eval(attrs["value"])))
            self.context.append(value)

        elif name == "define":
            define = Define(parent, attrs["name"], int(literal_eval(attrs["value"])))
            self.context.append(define)

        elif name == "condition":
            condition = Condition(parent, name=attrs["check"], ty=attrs["type"])

            # Starting with the if statement we push it in the context. For each
            # branch following (elif, and else) we assign the top of stack as
            # its parent, pop() and push the new condition. So per branch we end
            # up having [..., struct, condition]. We don't push an endif since
            # it's not supposed to have any children and it's supposed to close
            # the whole if statement.

            if condition.type != "if":
                # elif/else/endif must follow an open condition. Check this
                # before popping so an unrelated node (e.g. the enclosing
                # struct) is never removed from the stack by mistake.
                if not isinstance(parent, Condition):
                    if condition.type == "endif":
                        raise RuntimeError("Cannot close unopened or already closed condition. Condition: '%s'"
                                           % condition.name)
                    raise RuntimeError("Found '%s' without an open condition. Condition: '%s'"
                                       % (condition.type, condition.name))

                # Remove the parent condition from the context. We were peeking
                # before, now we pop().
                self.context.pop()

            if condition.type != "endif":
                self.context.append(condition)

        else:
            raise RuntimeError("Unknown tag: '%s'" % name)

    def end_element(self, name: str) -> None:
        """Handle a closing tag by popping and checking the matching node.

        Closing the ``csbgen`` tag calls ``emit()`` on the root node.

        :param name: The tag name.
        :raises RuntimeError: if the closed node does not match the tag, or
            the tag is unknown.
        """
        if name == "condition":
            # Conditions are not popped when their tag closes. Branches are
            # swapped in start_element, and an endif is never pushed.
            element = self.context[-1]
            if not isinstance(element, (Condition, Struct)):
                raise RuntimeError("Expected condition or struct tag to be closed.")

            return

        element = self.context.pop()

        if name == "struct":
            if not isinstance(element, Struct):
                raise RuntimeError("Expected struct tag to be closed.")
        elif name == "field":
            if not isinstance(element, Field):
                raise RuntimeError("Expected field tag to be closed.")
        elif name == "enum":
            if not isinstance(element, Enum):
                raise RuntimeError("Expected enum tag to be closed.")
        elif name == "value":
            if not isinstance(element, Value):
                raise RuntimeError("Expected value tag to be closed.")
        elif name == "define":
            if not isinstance(element, Define):
                raise RuntimeError("Expected define tag to be closed.")
        elif name == "csbgen":
            if not isinstance(element, Csbgen):
                raise RuntimeError("Expected csbgen tag to be closed.\nSome tags may have not been closed")

            element.emit()
        else:
            raise RuntimeError("Unknown closing element: '%s'" % name)

    def parse(self, filename: str) -> None:
        """Parse the XML file at ``filename``.

        The file is closed even if parsing raises.
        """
        with open(filename, "rb") as file:
            self.filename = filename
            self.parser.ParseFile(file)
1052
1053
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # This script prints its generated output to stdout, so the usage
        # error goes to stderr to keep it out of any captured output.
        print("No input xml file specified", file=sys.stderr)
        sys.exit(1)

    input_file = sys.argv[1]

    p = Parser()
    p.parse(input_file)
1065