1/* 2 * Copyright © 2015 Connor Abbott 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 * 23 */ 24 25/** 26 * nir_opt_vectorize() aims to vectorize ALU instructions. 27 * 28 * The default vectorization width is 4. 29 * If desired, a callback function which returns the max vectorization width 30 * per instruction can be provided. 31 * 32 * The max vectorization width must be a power of 2. 33 */ 34 35#include "nir.h" 36#include "nir_vla.h" 37#include "nir_builder.h" 38#include "util/u_dynarray.h" 39 40#define HASH(hash, data) XXH32(&data, sizeof(data), hash) 41 42static uint32_t 43hash_src(uint32_t hash, const nir_src *src) 44{ 45 assert(src->is_ssa); 46 void *hash_data = nir_src_is_const(*src) ? NULL : src->ssa; 47 48 return HASH(hash, hash_data); 49} 50 51static uint32_t 52hash_alu_src(uint32_t hash, const nir_alu_src *src, 53 uint32_t num_components, uint32_t max_vec) 54{ 55 assert(!src->abs && !src->negate); 56 57 /* hash whether a swizzle accesses elements beyond the maximum 58 * vectorization factor: 59 * For example accesses to .x and .y are considered different variables 60 * compared to accesses to .z and .w for 16-bit vec2. 61 */ 62 uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1)); 63 hash = HASH(hash, swizzle); 64 65 return hash_src(hash, &src->src); 66} 67 68static uint32_t 69hash_instr(const void *data) 70{ 71 const nir_instr *instr = (nir_instr *) data; 72 assert(instr->type == nir_instr_type_alu); 73 nir_alu_instr *alu = nir_instr_as_alu(instr); 74 75 uint32_t hash = HASH(0, alu->op); 76 hash = HASH(hash, alu->dest.dest.ssa.bit_size); 77 78 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) 79 hash = hash_alu_src(hash, &alu->src[i], 80 alu->dest.dest.ssa.num_components, 81 instr->pass_flags); 82 83 return hash; 84} 85 86static bool 87srcs_equal(const nir_src *src1, const nir_src *src2) 88{ 89 assert(src1->is_ssa); 90 assert(src2->is_ssa); 91 92 return src1->ssa == src2->ssa || 93 (nir_src_is_const(*src1) && nir_src_is_const(*src2)); 94} 95 96static bool 97alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2, 98 uint32_t max_vec) 99{ 100 assert(!src1->abs); 101 assert(!src1->negate); 102 assert(!src2->abs); 103 assert(!src2->negate); 104 105 uint32_t mask = ~(max_vec - 1); 106 if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask)) 107 return false; 108 109 return srcs_equal(&src1->src, &src2->src); 110} 111 112static bool 113instrs_equal(const void *data1, const void *data2) 114{ 115 const nir_instr *instr1 = (nir_instr *) data1; 116 const nir_instr *instr2 = (nir_instr *) data2; 117 assert(instr1->type == nir_instr_type_alu); 118 assert(instr2->type == nir_instr_type_alu); 119 120 nir_alu_instr *alu1 = nir_instr_as_alu(instr1); 121 nir_alu_instr *alu2 = nir_instr_as_alu(instr2); 122 123 if (alu1->op != alu2->op) 124 return false; 125 126 if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size) 127 return false; 128 129 for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) { 130 if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags)) 131 return false; 132 } 133 134 return true; 135} 136 137static bool 138instr_can_rewrite(nir_instr *instr) 139{ 140 switch (instr->type) { 141 case nir_instr_type_alu: { 142 nir_alu_instr *alu = nir_instr_as_alu(instr); 143 144 /* Don't try and vectorize mov's. Either they'll be handled by copy 145 * prop, or they're actually necessary and trying to vectorize them 146 * would result in fighting with copy prop. 147 */ 148 if (alu->op == nir_op_mov) 149 return false; 150 151 /* no need to hash instructions which are already vectorized */ 152 if (alu->dest.dest.ssa.num_components >= instr->pass_flags) 153 return false; 154 155 if (nir_op_infos[alu->op].output_size != 0) 156 return false; 157 158 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { 159 if (nir_op_infos[alu->op].input_sizes[i] != 0) 160 return false; 161 162 /* don't hash instructions which are already swizzled 163 * outside of max_components: these should better be scalarized */ 164 uint32_t mask = ~(instr->pass_flags - 1); 165 for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) { 166 if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask)) 167 return false; 168 } 169 } 170 171 return true; 172 } 173 174 /* TODO support phi nodes */ 175 default: 176 break; 177 } 178 179 return false; 180} 181 182/* 183 * Tries to combine two instructions whose sources are different components of 184 * the same instructions into one vectorized instruction. Note that instr1 185 * should dominate instr2. 186 */ 187static nir_instr * 188instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2) 189{ 190 assert(instr1->type == nir_instr_type_alu); 191 assert(instr2->type == nir_instr_type_alu); 192 nir_alu_instr *alu1 = nir_instr_as_alu(instr1); 193 nir_alu_instr *alu2 = nir_instr_as_alu(instr2); 194 195 assert(alu1->dest.dest.ssa.bit_size == alu2->dest.dest.ssa.bit_size); 196 unsigned alu1_components = alu1->dest.dest.ssa.num_components; 197 unsigned alu2_components = alu2->dest.dest.ssa.num_components; 198 unsigned total_components = alu1_components + alu2_components; 199 200 assert(instr1->pass_flags == instr2->pass_flags); 201 if (total_components > instr1->pass_flags) 202 return NULL; 203 204 nir_builder b; 205 nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node)); 206 b.cursor = nir_after_instr(instr1); 207 208 nir_alu_instr *new_alu = nir_alu_instr_create(b.shader, alu1->op); 209 nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest, 210 total_components, alu1->dest.dest.ssa.bit_size, NULL); 211 new_alu->dest.write_mask = (1 << total_components) - 1; 212 new_alu->instr.pass_flags = alu1->instr.pass_flags; 213 214 /* If either channel is exact, we have to preserve it even if it's 215 * not optimal for other channels. 216 */ 217 new_alu->exact = alu1->exact || alu2->exact; 218 219 /* If all channels don't wrap, we can say that the whole vector doesn't 220 * wrap. 221 */ 222 new_alu->no_signed_wrap = alu1->no_signed_wrap && alu2->no_signed_wrap; 223 new_alu->no_unsigned_wrap = alu1->no_unsigned_wrap && alu2->no_unsigned_wrap; 224 225 for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) { 226 /* handle constant merging case */ 227 if (alu1->src[i].src.ssa != alu2->src[i].src.ssa) { 228 nir_const_value *c1 = nir_src_as_const_value(alu1->src[i].src); 229 nir_const_value *c2 = nir_src_as_const_value(alu2->src[i].src); 230 assert(c1 && c2); 231 nir_const_value value[NIR_MAX_VEC_COMPONENTS]; 232 unsigned bit_size = alu1->src[i].src.ssa->bit_size; 233 234 for (unsigned j = 0; j < total_components; j++) { 235 value[j].u64 = j < alu1_components ? 236 c1[alu1->src[i].swizzle[j]].u64 : 237 c2[alu2->src[i].swizzle[j - alu1_components]].u64; 238 } 239 nir_ssa_def *def = nir_build_imm(&b, total_components, bit_size, value); 240 241 new_alu->src[i].src = nir_src_for_ssa(def); 242 for (unsigned j = 0; j < total_components; j++) 243 new_alu->src[i].swizzle[j] = j; 244 continue; 245 } 246 247 new_alu->src[i].src = alu1->src[i].src; 248 249 for (unsigned j = 0; j < alu1_components; j++) 250 new_alu->src[i].swizzle[j] = alu1->src[i].swizzle[j]; 251 252 for (unsigned j = 0; j < alu2_components; j++) { 253 new_alu->src[i].swizzle[j + alu1_components] = 254 alu2->src[i].swizzle[j]; 255 } 256 } 257 258 nir_builder_instr_insert(&b, &new_alu->instr); 259 260 /* update all ALU uses */ 261 nir_foreach_use_safe(src, &alu1->dest.dest.ssa) { 262 nir_instr *user_instr = src->parent_instr; 263 if (user_instr->type == nir_instr_type_alu) { 264 /* Check if user is found in the hashset */ 265 struct set_entry *entry = _mesa_set_search(instr_set, user_instr); 266 267 /* For ALU instructions, rewrite the source directly to avoid a 268 * round-trip through copy propagation. 269 */ 270 nir_instr_rewrite_src(user_instr, src, 271 nir_src_for_ssa(&new_alu->dest.dest.ssa)); 272 273 /* Rehash user if it was found in the hashset */ 274 if (entry && entry->key == user_instr) { 275 _mesa_set_remove(instr_set, entry); 276 _mesa_set_add(instr_set, user_instr); 277 } 278 } 279 } 280 281 nir_foreach_use_safe(src, &alu2->dest.dest.ssa) { 282 if (src->parent_instr->type == nir_instr_type_alu) { 283 /* For ALU instructions, rewrite the source directly to avoid a 284 * round-trip through copy propagation. 285 */ 286 nir_instr_rewrite_src(src->parent_instr, src, 287 nir_src_for_ssa(&new_alu->dest.dest.ssa)); 288 289 nir_alu_src *alu_src = container_of(src, nir_alu_src, src); 290 nir_alu_instr *use = nir_instr_as_alu(src->parent_instr); 291 unsigned components = nir_ssa_alu_instr_src_components(use, alu_src - use->src); 292 for (unsigned i = 0; i < components; i++) 293 alu_src->swizzle[i] += alu1_components; 294 } 295 } 296 297 /* update all other uses if there are any */ 298 unsigned swiz[NIR_MAX_VEC_COMPONENTS]; 299 300 if (!nir_ssa_def_is_unused(&alu1->dest.dest.ssa)) { 301 for (unsigned i = 0; i < alu1_components; i++) 302 swiz[i] = i; 303 nir_ssa_def *new_alu1 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz, 304 alu1_components); 305 nir_ssa_def_rewrite_uses(&alu1->dest.dest.ssa, new_alu1); 306 } 307 308 if (!nir_ssa_def_is_unused(&alu2->dest.dest.ssa)) { 309 for (unsigned i = 0; i < alu2_components; i++) 310 swiz[i] = i + alu1_components; 311 nir_ssa_def *new_alu2 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz, 312 alu2_components); 313 nir_ssa_def_rewrite_uses(&alu2->dest.dest.ssa, new_alu2); 314 } 315 316 nir_instr_remove(instr1); 317 nir_instr_remove(instr2); 318 319 return &new_alu->instr; 320} 321 322static struct set * 323vec_instr_set_create(void) 324{ 325 return _mesa_set_create(NULL, hash_instr, instrs_equal); 326} 327 328static void 329vec_instr_set_destroy(struct set *instr_set) 330{ 331 _mesa_set_destroy(instr_set, NULL); 332} 333 334static bool 335vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr, 336 nir_vectorize_cb filter, void *data) 337{ 338 /* set max vector to instr pass flags: this is used to hash swizzles */ 339 instr->pass_flags = filter ? filter(instr, data) : 4; 340 assert(util_is_power_of_two_or_zero(instr->pass_flags)); 341 342 if (!instr_can_rewrite(instr)) 343 return false; 344 345 struct set_entry *entry = _mesa_set_search(instr_set, instr); 346 if (entry) { 347 nir_instr *old_instr = (nir_instr *) entry->key; 348 _mesa_set_remove(instr_set, entry); 349 nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr); 350 if (new_instr) { 351 if (instr_can_rewrite(new_instr)) 352 _mesa_set_add(instr_set, new_instr); 353 return true; 354 } 355 } 356 357 _mesa_set_add(instr_set, instr); 358 return false; 359} 360 361static bool 362vectorize_block(nir_block *block, struct set *instr_set, 363 nir_vectorize_cb filter, void *data) 364{ 365 bool progress = false; 366 367 nir_foreach_instr_safe(instr, block) { 368 if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data)) 369 progress = true; 370 } 371 372 for (unsigned i = 0; i < block->num_dom_children; i++) { 373 nir_block *child = block->dom_children[i]; 374 progress |= vectorize_block(child, instr_set, filter, data); 375 } 376 377 nir_foreach_instr_reverse(instr, block) { 378 if (instr_can_rewrite(instr)) 379 _mesa_set_remove_key(instr_set, instr); 380 } 381 382 return progress; 383} 384 385static bool 386nir_opt_vectorize_impl(nir_function_impl *impl, 387 nir_vectorize_cb filter, void *data) 388{ 389 struct set *instr_set = vec_instr_set_create(); 390 391 nir_metadata_require(impl, nir_metadata_dominance); 392 393 bool progress = vectorize_block(nir_start_block(impl), instr_set, 394 filter, data); 395 396 if (progress) { 397 nir_metadata_preserve(impl, nir_metadata_block_index | 398 nir_metadata_dominance); 399 } else { 400 nir_metadata_preserve(impl, nir_metadata_all); 401 } 402 403 vec_instr_set_destroy(instr_set); 404 return progress; 405} 406 407bool 408nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter, 409 void *data) 410{ 411 bool progress = false; 412 413 nir_foreach_function(function, shader) { 414 if (function->impl) 415 progress |= nir_opt_vectorize_impl(function->impl, filter, data); 416 } 417 418 return progress; 419} 420