#encoding=utf-8

# Copyright (C) 2021 Collabora, Ltd.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import argparse
import sys
import struct
from valhall import instructions, enums, immediates, typesize

# Text of the line currently being assembled; printed by die() so that
# standalone error output shows the offending source line.
LINE = ''


class ParseError(Exception):
    """Raised for invalid assembly when this module is used as a library.

    The message is available both as ``error`` (the original attribute) and
    through the normal exception machinery (``str(exc)``, ``exc.args``).
    """

    def __init__(self, error):
        # Forward the message to Exception so str(exc) and tracebacks show it
        super().__init__(error)
        self.error = error


class FAUState:
    """Tracks the FAU operands used by one instruction and validates them.

    Attributes:
        message: whether the instruction is a message instruction; such
            instructions are allowed one extra FAU slot and word.
        page: the FAU page selected by the operands seen so far, or None
            if no operand has selected a page yet.
        words: encoded FAU sources that are not small constants.
        buffer: every encoded non-register source, small constants included.
    """

    def __init__(self, message = False):
        self.message = message
        self.page = None
        self.words = set()
        self.buffer = set()

    def set_page(self, page):
        """Select FAU page `page`, failing if another page is already in use."""
        assert(page <= 3)
        die_if(self.page is not None and self.page != page, 'Mismatched pages')
        self.page = page

    def push(self, source):
        """Record the encoded source `source` and enforce the FAU limits.

        Sources with bit 7 clear are registers and are ignored. Fails via
        die_if() when more than two distinct non-register sources are used,
        or when the slot/word limits are exceeded.
        """
        if not (source & (1 << 7)):
            # Skip registers
            return

        self.buffer.add(source)
        die_if(len(self.buffer) > 2, "Overflowed FAU buffer")

        if (source >> 5) == 0b110:
            # Small constants only count against the buffer limit checked
            # above, nothing else
            return

        self.words.add(source)

        # Check the encoded slots: two words that differ only in bit 0 share
        # a slot
        slots = {(x >> 1) for x in self.words}
        die_if(len(slots) > (2 if self.message else 1), 'Too many FAU slots')
        die_if(len(self.words) > (3 if self.message else 2), 'Too many FAU words')


# When running standalone, exit with the error since we're dealing with a
# human. Otherwise raise a Python exception so the test harness can handle it.
def die(s):
    """Report the fatal assembly error `s`; this function never returns.

    Run as a script: print the offending line and the message, then exit
    with status 1. Imported as a module: raise ParseError(s) instead.
    """
    if __name__ == "__main__":
        print(LINE)
        print(s)
        sys.exit(1)
    else:
        raise ParseError(s)


def die_if(cond, s):
    """Call die(s) when `cond` is truthy; otherwise do nothing."""
    if cond:
        die(s)


def parse_int(s, minimum, maximum):
    """Parse `s` as an integer and check it lies in [minimum, maximum].

    The base is inferred from the prefix (``0x``, ``0o``, ``0b`` or decimal).
    Fails via die() if `s` is not a number or is out of range.
    """
    try:
        number = int(s, base = 0)
    except ValueError:
        die(f"Expected number {s}")

    if number > maximum or number < minimum:
        die(f"Range error on {s}")

    return number


def encode_source(op, fau):
    """Encode the base of a source operand (without `.modifier` suffixes).

    Accepted forms, in the order they are tried:
      ``^rN``   register N with bit 6 (0x40) set to mark a discard
      ``rN``    register N
      ``uN``    N in 0..127; N >> 6 selects the FAU page on `fau` and the
                low six bits are returned with bit 7 (0x80) set
      ``i..N``  everything after the first three characters is a decimal
                index, returned with 0xC0 set
      ``0x..``  a value listed in `immediates`; its index is returned with
                0xC0 set
      other     a name from one of the `fau_special_page_{0,1,3}` enums

    Returns the encoded source as an integer. Fails via die() on anything
    else. `fau` is only touched by forms that select a page.
    """
    # Prefix tests use startswith()/slices so truncated operands such as
    # "^" or "" are reported through die() rather than raising IndexError.
    if op.startswith('^'):
        die_if(op[1:2] != 'r', f"Expected register after discard {op}")
        return parse_int(op[2:], 0, 63) | 0x40
    elif op.startswith('r'):
        return parse_int(op[1:], 0, 63)
    elif op.startswith('u'):
        val = parse_int(op[1:], 0, 127)
        fau.set_page(val >> 6)
        return (val & 0x3F) | 0x80
    elif op.startswith('i'):
        # NOTE(review): the value is not range checked here; confirm the
        # valid range against the immediates table.
        try:
            index = int(op[3:])
        except ValueError:
            die(f"Expected number {op}")
        return index | 0xC0
    elif op.startswith('0x'):
        try:
            val = int(op, base=0)
        except ValueError:
            die('Expected value')

        die_if(val not in immediates, 'Unexpected immediate value')
        return immediates.index(val) | 0xC0
    else:
        for i in [0, 1, 3]:
            special = enums[f'fau_special_page_{i}'].bare_values
            if op in special:
                # Special values start at index 32 and occupy two words each
                idx = 32 + (special.index(op) << 1)
                fau.set_page(i)
                return idx | 0xC0

        die('Invalid operand')


def encode_dest(op):
    """Encode a destination of the form ``rN``, ``rN.h0`` or ``rN.h1``.

    Returns the register number in bits 0-5 and the write mask in bits 6-7.
    Without a modifier the mask is 0x3 (full write); ``h0`` and ``h1``
    select mask bits 0 and 1 respectively.
    """
    die_if(not op.startswith('r'), f"Expected register destination {op}")

    parts = op.split(".")
    reg = parts[0]

    # Default to writing in full
    wrmask = 0x3

    if len(parts) > 1:
        WMASKS = ["h0", "h1"]
        die_if(len(parts) > 2, "Too many modifiers")
        mask = parts[1]
        die_if(mask not in WMASKS, "Expected a write mask")
        wrmask = 1 << WMASKS.index(mask)

    return parse_int(reg[1:], 0, 63) | (wrmask << 6)


def _encode_staging(op, sr, ins, modifier_map):
    """Encode staging operand `op` (``@``, ``@rN`` or ``@rN:rM:...``).

    Returns the bits to OR into the instruction word. When the staging
    descriptor `sr` has a variable count (sr.count == 0), the register count
    is stored in `modifier_map` under the matching count modifier.
    """
    die_if(not op.startswith('@'), f'Expected staging register, got {op}')

    # A bare '@' is an empty register list
    parts = [] if op == '@' else op[1:].split(':')

    die_if(any(not x.startswith('r') for x in parts), f'Expected registers, got {op}')
    regs = [parse_int(x[1:], 0, 63) for x in parts]

    modifier_names = [x.name for x in ins.modifiers]
    extended_write = "staging_register_write_count" in modifier_names and sr.write
    max_sr_count = 8 if extended_write else 7

    sr_count = len(regs)
    die_if(sr_count > max_sr_count, f'Too many staging registers {sr_count}')

    base = regs[0] if regs else 0
    die_if(any(reg != (base + i) for i, reg in enumerate(regs)),
           f'Expected consecutive staging registers, got {op}')
    die_if(sr_count > 1 and (base % 2) != 0,
           'Consecutive staging registers must be aligned to a register pair')

    if sr.count == 0:
        if extended_write:
            # The write count is stored as count - 1, so an empty list would
            # store -1 and produce a negative instruction word.
            die_if(sr_count == 0, f'Expected at least one staging register, got {op}')
            modifier_map["staging_register_write_count"] = sr_count - 1
        else:
            assert "staging_register_count" in modifier_names
            modifier_map["staging_register_count"] = sr_count
    else:
        die_if(sr_count != sr.count, f"Expected {sr.count} staging registers, got {sr_count}")

    return (sr.encoded_flags | base) << sr.start


def _encode_src(op, src, fau):
    """Encode source operand `op` (base plus ``.modifier`` suffixes).

    Returns the bits to OR into the instruction word and records the encoded
    base in `fau`. At most one swizzle-like modifier is accepted per source.
    """
    bits = 0
    parts = op.split('.')
    encoded_src = encode_source(parts[0], fau)

    # Special FAU values (top three bits 0b111) are the only sources that
    # accept a w0/w1 word selection.
    # NOTE(review): nothing checks afterwards that a selection was actually
    # given; confirm whether w0 is meant to be the implicit default.
    needs_word_select = ((encoded_src >> 5) == 0b111)

    # Has a swizzle been applied yet?
    swizzled = False

    for mod in parts[1:]:
        # Encode the modifier
        if mod in src.offset and src.bits[mod] == 1:
            bits |= (1 << src.offset[mod])
        elif src.halfswizzle and mod in enums[f'half_swizzles_{src.size}_bit'].bare_values:
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums[f'half_swizzles_{src.size}_bit'].bare_values.index(mod)
            bits |= (val << src.offset['widen'])
        elif mod in enums[f'swizzles_{src.size}_bit'].bare_values and (src.widen or src.lanes):
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums[f'swizzles_{src.size}_bit'].bare_values.index(mod)
            bits |= (val << src.offset['widen'])
        elif src.lane and mod in enums[f'lane_{src.size}_bit'].bare_values:
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums[f'lane_{src.size}_bit'].bare_values.index(mod)
            bits |= (val << src.offset['lane'])
        elif src.combine and mod in enums['combine'].bare_values:
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums['combine'].bare_values.index(mod)
            bits |= (val << src.offset['combine'])
        elif src.size == 32 and mod in enums['widen'].bare_values:
            die_if(not src.swizzle, "Instruction doesn't take widens")
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums['widen'].bare_values.index(mod)
            bits |= (val << src.offset['swizzle'])
        elif src.size == 16 and mod in enums['swizzles_16_bit'].bare_values:
            die_if(not src.swizzle, "Instruction doesn't take swizzles")
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums['swizzles_16_bit'].bare_values.index(mod)
            bits |= (val << src.offset['swizzle'])
        elif mod in enums['lane_8_bit'].bare_values:
            die_if(not src.lane, "Instruction doesn't take a lane")
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums['lane_8_bit'].bare_values.index(mod)
            # NOTE(review): this shifts by src.lane while the lane branch
            # above shifts by src.offset['lane']; confirm which is intended.
            bits |= (val << src.lane)
        elif mod in enums['lanes_8_bit'].bare_values:
            die_if(not src.lanes, "Instruction doesn't take a lane")
            die_if(swizzled, "Multiple swizzles specified")
            swizzled = True
            val = enums['lanes_8_bit'].bare_values.index(mod)
            bits |= (val << src.offset['widen'])
        elif mod in ['w0', 'w1']:
            # Check for special
            die_if(not needs_word_select, 'Unexpected word select')

            if mod == 'w1':
                encoded_src |= 0x1

            needs_word_select = False
        else:
            die(f"Unknown modifier {mod}")

    # Encode the identity if a swizzle is required but not specified
    if src.swizzle and not swizzled and src.size == 16:
        mod = enums['swizzles_16_bit'].default
        val = enums['swizzles_16_bit'].bare_values.index(mod)
        bits |= (val << src.offset['swizzle'])
    elif src.widen and not swizzled and src.size == 16:
        mod = enums['swizzles_16_bit'].default
        val = enums['swizzles_16_bit'].bare_values.index(mod)
        bits |= (val << src.offset['widen'])

    fau.push(encoded_src)
    return bits | (encoded_src << src.start)


def _encode_immediate(op, imm):
    """Encode immediate operand `op` (``name:value``, or ``#value`` for the
    immediate named 'constant') and return the bits to OR into the word.
    """
    if op[0] == '#':
        die_if(imm.name != 'constant', "Wrong syntax for immediate")
        parts = [imm.name, op[1:]]
    else:
        parts = op.split(':')
        die_if(len(parts) != 2, f"Wrong syntax for immediate, wrong number of colons in {op}")
        die_if(parts[0] != imm.name, f"Wrong immediate, expected {imm.name}, got {parts[0]}")

    if imm.signed:
        minimum = -(1 << (imm.size - 1))
        maximum = +(1 << (imm.size - 1)) - 1
    else:
        minimum = 0
        maximum = (1 << imm.size) - 1

    val = parse_int(parts[1], minimum, maximum)

    if val < 0:
        # Two's complement within the field width
        val = (1 << imm.size) + val

    return val << imm.start


def parse_asm(line):
    """Assemble one line of text into an integer instruction word.

    The line is ``MNEMONIC[.mod...] operand, operand, ...`` with operands in
    the order staging registers, destinations, sources, immediates. The
    longest instruction name that prefixes the first word is selected.

    Returns the encoded instruction as an int; the command-line driver packs
    it as a little-endian 64-bit word. All user errors go through die().
    """
    global LINE
    LINE = line # For better errors
    encoded = 0

    # Figure out mnemonic: longest matching instruction name wins
    head = line.split(" ")[0]
    opts = [ins for ins in instructions if head.startswith(ins.name)]
    opts = sorted(opts, key=lambda x: len(x.name), reverse=True)

    if len(opts) == 0:
        die(f"No known mnemonic for {head}")

    if len(opts) > 1 and len(opts[0].name) == len(opts[1].name):
        # Report through die() so that library users get a ParseError
        # instead of an unconditional SystemExit.
        options = "\n".join(f"  {candidate}" for candidate in opts)
        die(f"Ambiguous mnemonic for {head}\nOptions:\n{options}")

    ins = opts[0]

    # Split off modifiers
    if len(head) > len(ins.name) and head[len(ins.name)] != '.':
        die(f"Expected . after instruction in {head}")

    mods = head[len(ins.name) + 1:].split(".")
    modifier_map = {}

    tail = line[(len(head) + 1):]
    operands = [x.strip() for x in tail.split(",") if len(x.strip()) > 0]
    expected_op_count = len(ins.srcs) + len(ins.dests) + len(ins.immediates) + len(ins.staging)
    if len(operands) != expected_op_count:
        die(f"Wrong number of operands in {line}, expected {expected_op_count}, got {len(operands)} {operands}")

    # Encode each operand, consuming them group by group
    for op, sr in zip(operands, ins.staging):
        encoded |= _encode_staging(op, sr, ins, modifier_map)
    operands = operands[len(ins.staging):]

    for op, dest in zip(operands, ins.dests):
        encoded |= encode_dest(op) << 40
    operands = operands[len(ins.dests):]

    if len(ins.dests) == 0 and len(ins.staging) == 0:
        # Set a placeholder writemask to prevent encoding faults
        encoded |= (0xC0 << 40)

    fau = FAUState(message = ins.message)

    for op, src in zip(operands, ins.srcs):
        encoded |= _encode_src(op, src, fau)
    operands = operands[len(ins.srcs):]

    for op, imm in zip(operands, ins.immediates):
        encoded |= _encode_immediate(op, imm)
    operands = operands[len(ins.immediates):]

    # Encode the operation itself
    encoded |= (ins.opcode << 48)
    encoded |= (ins.opcode2 << ins.secondary_shift)

    # Encode FAU page (None and page 0 both contribute no bits)
    if fau.page:
        encoded |= (fau.page << 57)

    # Encode modifiers given on the mnemonic
    has_flow = False
    for mod in mods:
        if len(mod) == 0:
            continue

        if mod in enums['flow'].bare_values:
            die_if(has_flow, "Multiple flow control modifiers specified")
            has_flow = True
            encoded |= (enums['flow'].bare_values.index(mod) << 59)
        else:
            candidates = [c for c in ins.modifiers if mod in c.bare_values]

            die_if(len(candidates) == 0, f"Invalid modifier {mod} used")
            assert(len(candidates) == 1) # No ambiguous modifiers
            modifier = candidates[0]

            value = modifier.bare_values.index(mod)

            die_if(modifier.name in modifier_map, f"{modifier.name} specified twice")
            modifier_map[modifier.name] = value

    # Every modifier of the instruction is encoded, falling back to its
    # default when the line did not set it
    for mod in ins.modifiers:
        value = modifier_map.get(mod.name, mod.default)
        die_if(value is None, f"Missing required modifier {mod.name}")

        assert(value < (1 << mod.size))
        encoded |= (value << mod.start)

    return encoded


if __name__ == "__main__":
    # Provide commandline interface
    parser = argparse.ArgumentParser(description='Assemble Valhall shaders')
    parser.add_argument('infile', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin)
    parser.add_argument('outfile', type=argparse.FileType('wb'))
    args = parser.parse_args()

    # Skip empty lines and lines starting with '#'
    lines = args.infile.read().strip().split('\n')
    lines = [ln for ln in lines if ln and not ln.startswith('#')]

    # One little-endian 64-bit word per instruction
    packed = b''.join([struct.pack('<Q', parse_asm(ln)) for ln in lines])
    args.outfile.write(packed)