1/* 2 * Copyright © 2021 Valve Corporation 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 * 23 */ 24 25#include "ac_nir.h" 26#include "nir_builder.h" 27#include "u_math.h" 28#include "u_vector.h" 29 30enum { 31 nggc_passflag_used_by_pos = 1, 32 nggc_passflag_used_by_other = 2, 33 nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other, 34}; 35 36typedef struct 37{ 38 nir_ssa_def *ssa; 39 nir_variable *var; 40} saved_uniform; 41 42typedef struct 43{ 44 nir_variable *position_value_var; 45 nir_variable *prim_exp_arg_var; 46 nir_variable *es_accepted_var; 47 nir_variable *gs_accepted_var; 48 nir_variable *gs_vtx_indices_vars[3]; 49 50 struct u_vector saved_uniforms; 51 52 bool passthrough; 53 bool export_prim_id; 54 bool early_prim_export; 55 bool use_edgeflags; 56 bool has_prim_query; 57 bool can_cull; 58 unsigned wave_size; 59 unsigned max_num_waves; 60 unsigned num_vertices_per_primitives; 61 unsigned provoking_vtx_idx; 62 unsigned max_es_num_vertices; 63 unsigned total_lds_bytes; 64 65 uint64_t inputs_needed_by_pos; 66 uint64_t inputs_needed_by_others; 67 uint32_t instance_rate_inputs; 68 69 nir_instr *compact_arg_stores[4]; 70 nir_intrinsic_instr *overwrite_args; 71} lower_ngg_nogs_state; 72 73typedef struct 74{ 75 /* Bitmask of components used: 4 bits per slot, 1 bit per component. */ 76 uint8_t components_mask : 4; 77 /* output stream index */ 78 uint8_t stream : 2; 79} gs_output_info; 80 81typedef struct 82{ 83 nir_variable *output_vars[VARYING_SLOT_MAX][4]; 84 nir_variable *current_clear_primflag_idx_var; 85 int const_out_vtxcnt[4]; 86 int const_out_prmcnt[4]; 87 unsigned wave_size; 88 unsigned max_num_waves; 89 unsigned num_vertices_per_primitive; 90 unsigned lds_addr_gs_out_vtx; 91 unsigned lds_addr_gs_scratch; 92 unsigned lds_bytes_per_gs_out_vertex; 93 unsigned lds_offs_primflags; 94 bool found_out_vtxcnt[4]; 95 bool output_compile_time_known; 96 bool provoking_vertex_last; 97 gs_output_info output_info[VARYING_SLOT_MAX]; 98} lower_ngg_gs_state; 99 100/* LDS layout of Mesh Shader workgroup info. */ 101enum { 102 /* DW0: number of primitives */ 103 lds_ms_num_prims = 0, 104 /* DW1: reserved for future use */ 105 lds_ms_dw1_reserved = 4, 106 /* DW2: workgroup index within the current dispatch */ 107 lds_ms_wg_index = 8, 108 /* DW3: number of API workgroups in flight */ 109 lds_ms_num_api_waves = 12, 110}; 111 112/* Potential location for Mesh Shader outputs. */ 113typedef enum { 114 ms_out_mode_lds, 115 ms_out_mode_vram, 116 ms_out_mode_var, 117} ms_out_mode; 118 119typedef struct 120{ 121 uint64_t mask; /* Mask of output locations */ 122 uint32_t addr; /* Base address */ 123} ms_out_part; 124 125typedef struct 126{ 127 /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */ 128 struct { 129 uint32_t workgroup_info_addr; 130 ms_out_part vtx_attr; 131 ms_out_part prm_attr; 132 uint32_t indices_addr; 133 uint32_t total_size; 134 } lds; 135 /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */ 136 struct { 137 ms_out_part vtx_attr; 138 ms_out_part prm_attr; 139 } vram; 140 /* Outputs without cross-invocation access can be stored in variables. */ 141 struct { 142 ms_out_part vtx_attr; 143 ms_out_part prm_attr; 144 } var; 145} ms_out_mem_layout; 146 147typedef struct 148{ 149 ms_out_mem_layout layout; 150 uint64_t per_vertex_outputs; 151 uint64_t per_primitive_outputs; 152 unsigned vertices_per_prim; 153 154 unsigned wave_size; 155 unsigned api_workgroup_size; 156 unsigned hw_workgroup_size; 157 158 nir_ssa_def *workgroup_index; 159 nir_variable *out_variables[VARYING_SLOT_MAX * 4]; 160 161 /* True if the lowering needs to insert the layer output. */ 162 bool insert_layer_output; 163 164 struct { 165 /* Bitmask of components used: 4 bits per slot, 1 bit per component. */ 166 uint32_t components_mask; 167 } output_info[VARYING_SLOT_MAX]; 168} lower_ngg_ms_state; 169 170typedef struct { 171 nir_variable *pre_cull_position_value_var; 172} remove_culling_shader_outputs_state; 173 174typedef struct { 175 nir_variable *pos_value_replacement; 176} remove_extra_position_output_state; 177 178/* Per-vertex LDS layout of culling shaders */ 179enum { 180 /* Position of the ES vertex (at the beginning for alignment reasons) */ 181 lds_es_pos_x = 0, 182 lds_es_pos_y = 4, 183 lds_es_pos_z = 8, 184 lds_es_pos_w = 12, 185 186 /* 1 when the vertex is accepted, 0 if it should be culled */ 187 lds_es_vertex_accepted = 16, 188 /* ID of the thread which will export the current thread's vertex */ 189 lds_es_exporter_tid = 17, 190 191 /* Repacked arguments - also listed separately for VS and TES */ 192 lds_es_arg_0 = 20, 193 194 /* VS arguments which need to be repacked */ 195 lds_es_vs_vertex_id = 20, 196 lds_es_vs_instance_id = 24, 197 198 /* TES arguments which need to be repacked */ 199 lds_es_tes_u = 20, 200 lds_es_tes_v = 24, 201 lds_es_tes_rel_patch_id = 28, 202 lds_es_tes_patch_id = 32, 203}; 204 205typedef struct { 206 nir_ssa_def *num_repacked_invocations; 207 nir_ssa_def *repacked_invocation_index; 208} wg_repack_result; 209 210/** 211 * Computes a horizontal sum of 8-bit packed values loaded from LDS. 212 * 213 * Each lane N will sum packed bytes 0 to N-1. 214 * We only care about the results from up to wave_id+1 lanes. 215 * (Other lanes are not deactivated but their calculation is not used.) 216 */ 217static nir_ssa_def * 218summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords) 219{ 220 /* We'll use shift to filter out the bytes not needed by the current lane. 221 * 222 * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes). 223 * However, two shifts are needed because one can't go all the way, 224 * so the shift amount is half that (and in bits). 225 * 226 * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes. 227 * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions, 228 * therefore v_dot can get rid of the unneeded values. 229 * This sequence is preferable because it better hides the latency of the LDS. 230 * 231 * If the v_dot instruction can't be used, we left-shift the packed bytes. 232 * This will shift out the unneeded bytes and shift in zeroes instead, 233 * then we sum them using v_sad_u8. 234 */ 235 236 nir_ssa_def *lane_id = nir_load_subgroup_invocation(b); 237 nir_ssa_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16); 238 bool use_dot = b->shader->options->has_udot_4x8; 239 240 if (num_lds_dwords == 1) { 241 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift); 242 243 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ 244 nir_ssa_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0)); 245 246 /* Horizontally add the packed bytes. */ 247 if (use_dot) { 248 return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0)); 249 } else { 250 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift); 251 return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0)); 252 } 253 } else if (num_lds_dwords == 2) { 254 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift); 255 256 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ 257 nir_ssa_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); 258 nir_ssa_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); 259 260 /* Horizontally add the packed bytes. */ 261 if (use_dot) { 262 nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0)); 263 return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum); 264 } else { 265 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift); 266 nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0)); 267 return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum); 268 } 269 } else { 270 unreachable("Unimplemented NGG wave count"); 271 } 272} 273 274/** 275 * Repacks invocations in the current workgroup to eliminate gaps between them. 276 * 277 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave). 278 * Assumes that all invocations in the workgroup are active (exec = -1). 279 */ 280static wg_repack_result 281repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool, 282 unsigned lds_addr_base, unsigned max_num_waves, 283 unsigned wave_size) 284{ 285 /* Input boolean: 1 if the current invocation should survive the repack. */ 286 assert(input_bool->bit_size == 1); 287 288 /* STEP 1. Count surviving invocations in the current wave. 289 * 290 * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask. 291 */ 292 293 nir_ssa_def *input_mask = nir_ballot(b, 1, wave_size, input_bool); 294 nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask); 295 296 /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */ 297 if (max_num_waves == 1) { 298 wg_repack_result r = { 299 .num_repacked_invocations = surviving_invocations_in_current_wave, 300 .repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)), 301 }; 302 return r; 303 } 304 305 /* STEP 2. Waves tell each other their number of surviving invocations. 306 * 307 * Each wave activates only its first lane (exec = 1), which stores the number of surviving 308 * invocations in that wave into the LDS, then reads the numbers from every wave. 309 * 310 * The workgroup size of NGG shaders is at most 256, which means 311 * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode. 312 * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary. 313 */ 314 315 const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4); 316 assert(num_lds_dwords <= 2); 317 318 nir_ssa_def *wave_id = nir_load_subgroup_id(b); 319 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32); 320 nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1)); 321 322 nir_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base); 323 324 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 325 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 326 327 nir_ssa_def *packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u); 328 329 nir_pop_if(b, if_first_lane); 330 331 packed_counts = nir_if_phi(b, packed_counts, dont_care); 332 333 /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations. 334 * 335 * By now, every wave knows the number of surviving invocations in all waves. 336 * Each number is 1 byte, and they are packed into up to 2 dwords. 337 * 338 * Each lane N will sum the number of surviving invocations from waves 0 to N-1. 339 * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this. 340 * (Other lanes are not deactivated but their calculation is not used.) 341 * 342 * - We read the sum from the lane whose id is the current wave's id. 343 * Add the masked bitcount to this, and we get the repacked invocation index. 344 * - We read the sum from the lane whose id is the number of waves in the workgroup. 345 * This is the total number of surviving invocations in the workgroup. 346 */ 347 348 nir_ssa_def *num_waves = nir_load_num_subgroups(b); 349 nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords); 350 351 nir_ssa_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id); 352 nir_ssa_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves); 353 nir_ssa_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base); 354 355 wg_repack_result r = { 356 .num_repacked_invocations = wg_num_repacked_invocations, 357 .repacked_invocation_index = wg_repacked_index, 358 }; 359 360 return r; 361} 362 363static nir_ssa_def * 364pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes) 365{ 366 return nir_imul_imm(b, vertex_idx, per_vtx_bytes); 367} 368 369static nir_ssa_def * 370emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives, 371 nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim, 372 bool use_edgeflags) 373{ 374 nir_ssa_def *arg = use_edgeflags 375 ? nir_load_initial_edgeflags_amd(b) 376 : nir_imm_int(b, 0); 377 378 for (unsigned i = 0; i < num_vertices_per_primitives; ++i) { 379 assert(vertex_indices[i]); 380 arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i))); 381 } 382 383 if (is_null_prim) { 384 if (is_null_prim->bit_size == 1) 385 is_null_prim = nir_b2i32(b, is_null_prim); 386 assert(is_null_prim->bit_size == 32); 387 arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u))); 388 } 389 390 return arg; 391} 392 393static void 394ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *st) 395{ 396 for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v) { 397 st->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr"); 398 399 nir_ssa_def *vtx = nir_ubfe(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u), 400 nir_imm_int(b, (v & 1u) * 16u), nir_imm_int(b, 16u)); 401 nir_store_var(b, st->gs_vtx_indices_vars[v], vtx, 0x1); 402 } 403} 404 405static nir_ssa_def * 406emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st) 407{ 408 if (st->passthrough) { 409 assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX); 410 return nir_load_packed_passthrough_primitive_amd(b); 411 } else { 412 nir_ssa_def *vtx_idx[3] = {0}; 413 414 for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v) 415 vtx_idx[v] = nir_load_var(b, st->gs_vtx_indices_vars[v]); 416 417 return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags); 418 } 419} 420 421static void 422emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg) 423{ 424 nir_ssa_def *gs_thread = st->gs_accepted_var 425 ? nir_load_var(b, st->gs_accepted_var) 426 : nir_has_input_primitive_amd(b); 427 428 nir_if *if_gs_thread = nir_push_if(b, gs_thread); 429 { 430 if (!arg) 431 arg = emit_ngg_nogs_prim_exp_arg(b, st); 432 433 if (st->has_prim_query) { 434 nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b)); 435 { 436 /* Number of active GS threads. Each has 1 output primitive. */ 437 nir_ssa_def *num_gs_threads = nir_bit_count(b, nir_ballot(b, 1, st->wave_size, nir_imm_bool(b, true))); 438 /* Activate only 1 lane and add the number of primitives to GDS. */ 439 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); 440 { 441 /* Use a different GDS offset than NGG GS to ensure that pipeline statistics 442 * queries won't return the number of primitives generated by VS/TES. 443 */ 444 nir_gds_atomic_add_amd(b, 32, num_gs_threads, nir_imm_int(b, 4), nir_imm_int(b, 0x100)); 445 } 446 nir_pop_if(b, if_elected); 447 } 448 nir_pop_if(b, if_shader_query); 449 } 450 451 nir_export_primitive_amd(b, arg); 452 } 453 nir_pop_if(b, if_gs_thread); 454} 455 456static void 457emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st) 458{ 459 nir_ssa_def *gs_thread = st->gs_accepted_var ? 460 nir_load_var(b, st->gs_accepted_var) : nir_has_input_primitive_amd(b); 461 462 nir_if *if_gs_thread = nir_push_if(b, gs_thread); 463 { 464 /* Copy Primitive IDs from GS threads to the LDS address 465 * corresponding to the ES thread of the provoking vertex. 466 * It will be exported as a per-vertex attribute. 467 */ 468 nir_ssa_def *prim_id = nir_load_primitive_id(b); 469 nir_ssa_def *provoking_vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[st->provoking_vtx_idx]); 470 nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u); 471 472 nir_store_shared(b, prim_id, addr); 473 } 474 nir_pop_if(b, if_gs_thread); 475} 476 477static void 478emit_store_ngg_nogs_es_primitive_id(nir_builder *b) 479{ 480 nir_ssa_def *prim_id = NULL; 481 482 if (b->shader->info.stage == MESA_SHADER_VERTEX) { 483 /* LDS address where the primitive ID is stored */ 484 nir_ssa_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b); 485 nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u); 486 487 /* Load primitive ID from LDS */ 488 prim_id = nir_load_shared(b, 1, 32, addr); 489 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) { 490 /* Just use tess eval primitive ID, which is the same as the patch ID. */ 491 prim_id = nir_load_primitive_id(b); 492 } 493 494 nir_io_semantics io_sem = { 495 .location = VARYING_SLOT_PRIMITIVE_ID, 496 .num_slots = 1, 497 }; 498 499 nir_store_output(b, prim_id, nir_imm_zero(b, 1, 32), 500 .base = io_sem.location, 501 .src_type = nir_type_uint32, .io_semantics = io_sem); 502} 503 504static bool 505remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state) 506{ 507 remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state; 508 509 if (instr->type != nir_instr_type_intrinsic) 510 return false; 511 512 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 513 514 /* These are not allowed in VS / TES */ 515 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output && 516 intrin->intrinsic != nir_intrinsic_load_per_vertex_input); 517 518 /* We are only interested in output stores now */ 519 if (intrin->intrinsic != nir_intrinsic_store_output) 520 return false; 521 522 b->cursor = nir_before_instr(instr); 523 524 /* Position output - store the value to a variable, remove output store */ 525 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 526 if (io_sem.location == VARYING_SLOT_POS) { 527 /* TODO: check if it's indirect, etc? */ 528 unsigned writemask = nir_intrinsic_write_mask(intrin); 529 nir_ssa_def *store_val = intrin->src[0].ssa; 530 nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask); 531 } 532 533 /* Remove all output stores */ 534 nir_instr_remove(instr); 535 return true; 536} 537 538static void 539remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var) 540{ 541 remove_culling_shader_outputs_state s = { 542 .pre_cull_position_value_var = pre_cull_position_value_var, 543 }; 544 545 nir_shader_instructions_pass(culling_shader, remove_culling_shader_output, 546 nir_metadata_block_index | nir_metadata_dominance, &s); 547 548 /* Remove dead code resulting from the deleted outputs. */ 549 bool progress; 550 do { 551 progress = false; 552 NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars); 553 NIR_PASS(progress, culling_shader, nir_opt_dce); 554 NIR_PASS(progress, culling_shader, nir_opt_dead_cf); 555 } while (progress); 556} 557 558static void 559rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel) 560{ 561 if (old_def->parent_instr->type == nir_instr_type_load_const) 562 return; 563 564 b->cursor = nir_after_instr(old_def->parent_instr); 565 if (b->cursor.instr->type == nir_instr_type_phi) 566 b->cursor = nir_after_phis(old_def->parent_instr->block); 567 568 nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var); 569 nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel); 570 571 if (old_def->num_components > 1) { 572 /* old_def uses a swizzled vector component. 573 * There is no way to replace the uses of just a single vector component, 574 * so instead create a new vector and replace all uses of the old vector. 575 */ 576 nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0}; 577 for (unsigned j = 0; j < old_def->num_components; ++j) 578 old_def_elements[j] = nir_channel(b, old_def, j); 579 replacement = nir_vec(b, old_def_elements, old_def->num_components); 580 } 581 582 nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr); 583} 584 585static bool 586remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state) 587{ 588 remove_extra_position_output_state *s = (remove_extra_position_output_state *) state; 589 590 if (instr->type != nir_instr_type_intrinsic) 591 return false; 592 593 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 594 595 /* These are not allowed in VS / TES */ 596 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output && 597 intrin->intrinsic != nir_intrinsic_load_per_vertex_input); 598 599 /* We are only interested in output stores now */ 600 if (intrin->intrinsic != nir_intrinsic_store_output) 601 return false; 602 603 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 604 if (io_sem.location != VARYING_SLOT_POS) 605 return false; 606 607 b->cursor = nir_before_instr(instr); 608 609 /* In case other outputs use what we calculated for pos, 610 * try to avoid calculating it again by rewriting the usages 611 * of the store components here. 612 */ 613 nir_ssa_def *store_val = intrin->src[0].ssa; 614 unsigned store_pos_component = nir_intrinsic_component(intrin); 615 616 nir_instr_remove(instr); 617 618 if (store_val->parent_instr->type == nir_instr_type_alu) { 619 nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr); 620 if (nir_op_is_vec(alu->op)) { 621 /* Output store uses a vector, we can easily rewrite uses of each vector element. */ 622 623 unsigned num_vec_src = 0; 624 if (alu->op == nir_op_mov) 625 num_vec_src = 1; 626 else if (alu->op == nir_op_vec2) 627 num_vec_src = 2; 628 else if (alu->op == nir_op_vec3) 629 num_vec_src = 3; 630 else if (alu->op == nir_op_vec4) 631 num_vec_src = 4; 632 assert(num_vec_src); 633 634 /* Remember the current components whose uses we wish to replace. 635 * This is needed because rewriting one source can affect the others too. 636 */ 637 nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0}; 638 for (unsigned i = 0; i < num_vec_src; i++) 639 vec_comps[i] = alu->src[i].src.ssa; 640 641 for (unsigned i = 0; i < num_vec_src; i++) 642 rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i); 643 } else { 644 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component); 645 } 646 } else { 647 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component); 648 } 649 650 return true; 651} 652 653static void 654remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state) 655{ 656 remove_extra_position_output_state s = { 657 .pos_value_replacement = nogs_state->position_value_var, 658 }; 659 660 nir_shader_instructions_pass(shader, remove_extra_pos_output, 661 nir_metadata_block_index | nir_metadata_dominance, &s); 662} 663 664static bool 665remove_compacted_arg(lower_ngg_nogs_state *state, nir_builder *b, unsigned idx) 666{ 667 nir_instr *store_instr = state->compact_arg_stores[idx]; 668 if (!store_instr) 669 return false; 670 671 /* Simply remove the store. */ 672 nir_instr_remove(store_instr); 673 674 /* Find the intrinsic that overwrites the shader arguments, 675 * and change its corresponding source. 676 * This will cause NIR's DCE to recognize the load and its phis as dead. 677 */ 678 b->cursor = nir_before_instr(&state->overwrite_args->instr); 679 nir_ssa_def *undef_arg = nir_ssa_undef(b, 1, 32); 680 nir_ssa_def_rewrite_uses(state->overwrite_args->src[idx].ssa, undef_arg); 681 682 state->compact_arg_stores[idx] = NULL; 683 return true; 684} 685 686static bool 687cleanup_culling_shader_after_dce(nir_shader *shader, 688 nir_function_impl *function_impl, 689 lower_ngg_nogs_state *state) 690{ 691 bool uses_vs_vertex_id = false; 692 bool uses_vs_instance_id = false; 693 bool uses_tes_u = false; 694 bool uses_tes_v = false; 695 bool uses_tes_rel_patch_id = false; 696 bool uses_tes_patch_id = false; 697 698 bool progress = false; 699 nir_builder b; 700 nir_builder_init(&b, function_impl); 701 702 nir_foreach_block_reverse_safe(block, function_impl) { 703 nir_foreach_instr_reverse_safe(instr, block) { 704 if (instr->type != nir_instr_type_intrinsic) 705 continue; 706 707 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 708 709 switch (intrin->intrinsic) { 710 case nir_intrinsic_alloc_vertices_and_primitives_amd: 711 goto cleanup_culling_shader_after_dce_done; 712 case nir_intrinsic_load_vertex_id: 713 case nir_intrinsic_load_vertex_id_zero_base: 714 uses_vs_vertex_id = true; 715 break; 716 case nir_intrinsic_load_instance_id: 717 uses_vs_instance_id = true; 718 break; 719 case nir_intrinsic_load_input: 720 if (state->instance_rate_inputs & 721 (1u << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0))) 722 uses_vs_instance_id = true; 723 else 724 uses_vs_vertex_id = true; 725 break; 726 case nir_intrinsic_load_tess_coord: 727 uses_tes_u = uses_tes_v = true; 728 break; 729 case nir_intrinsic_load_tess_rel_patch_id_amd: 730 uses_tes_rel_patch_id = true; 731 break; 732 case nir_intrinsic_load_primitive_id: 733 if (shader->info.stage == MESA_SHADER_TESS_EVAL) 734 uses_tes_patch_id = true; 735 break; 736 default: 737 break; 738 } 739 } 740 } 741 742 cleanup_culling_shader_after_dce_done: 743 744 if (shader->info.stage == MESA_SHADER_VERTEX) { 745 if (!uses_vs_vertex_id) 746 progress |= remove_compacted_arg(state, &b, 0); 747 if (!uses_vs_instance_id) 748 progress |= remove_compacted_arg(state, &b, 1); 749 } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) { 750 if (!uses_tes_u) 751 progress |= remove_compacted_arg(state, &b, 0); 752 if (!uses_tes_v) 753 progress |= remove_compacted_arg(state, &b, 1); 754 if (!uses_tes_rel_patch_id) 755 progress |= remove_compacted_arg(state, &b, 2); 756 if (!uses_tes_patch_id) 757 progress |= remove_compacted_arg(state, &b, 3); 758 } 759 760 return progress; 761} 762 763/** 764 * Perform vertex compaction after culling. 765 * 766 * 1. Repack surviving ES invocations (this determines which lane will export which vertex) 767 * 2. Surviving ES vertex invocations store their data to LDS 768 * 3. Emit GS_ALLOC_REQ 769 * 4. Repacked invocations load the vertex data from LDS 770 * 5. GS threads update their vertex indices 771 */ 772static void 773compact_vertices_after_culling(nir_builder *b, 774 lower_ngg_nogs_state *nogs_state, 775 nir_variable **repacked_arg_vars, 776 nir_variable **gs_vtxaddr_vars, 777 nir_ssa_def *invocation_index, 778 nir_ssa_def *es_vertex_lds_addr, 779 nir_ssa_def *es_exporter_tid, 780 nir_ssa_def *num_live_vertices_in_workgroup, 781 nir_ssa_def *fully_culled, 782 unsigned ngg_scratch_lds_base_addr, 783 unsigned pervertex_lds_bytes, 784 unsigned max_exported_args) 785{ 786 nir_variable *es_accepted_var = nogs_state->es_accepted_var; 787 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var; 788 nir_variable *position_value_var = nogs_state->position_value_var; 789 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var; 790 791 nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var)); 792 { 793 nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes); 794 795 /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */ 796 nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid); 797 798 /* Store the current thread's position output to the exporter thread's LDS space */ 799 nir_ssa_def *pos = nir_load_var(b, position_value_var); 800 nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x); 801 802 /* Store the current thread's repackable arguments to the exporter thread's LDS space */ 803 for (unsigned i = 0; i < max_exported_args; ++i) { 804 nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]); 805 nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i); 806 807 nogs_state->compact_arg_stores[i] = &store->instr; 808 } 809 } 810 nir_pop_if(b, if_es_accepted); 811 812 /* TODO: Consider adding a shortcut exit. 813 * Waves that have no vertices and primitives left can s_endpgm right here. 814 */ 815 816 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 817 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 818 819 nir_ssa_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup); 820 nir_if *if_packed_es_thread = nir_push_if(b, es_survived); 821 { 822 /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */ 823 nir_ssa_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x); 824 nir_store_var(b, position_value_var, exported_pos, 0xfu); 825 826 /* Read the repacked arguments */ 827 for (unsigned i = 0; i < max_exported_args; ++i) { 828 nir_ssa_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i); 829 nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u); 830 } 831 } 832 nir_push_else(b, if_packed_es_thread); 833 { 834 nir_store_var(b, position_value_var, nir_ssa_undef(b, 4, 32), 0xfu); 835 for (unsigned i = 0; i < max_exported_args; ++i) 836 nir_store_var(b, repacked_arg_vars[i], nir_ssa_undef(b, 1, 32), 0x1u); 837 } 838 nir_pop_if(b, if_packed_es_thread); 839 840 nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var)); 841 { 842 nir_ssa_def *exporter_vtx_indices[3] = {0}; 843 844 /* Load the index of the ES threads that will export the current GS thread's vertices */ 845 for (unsigned v = 0; v < 3; ++v) { 846 nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]); 847 nir_ssa_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid); 848 exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx); 849 nir_store_var(b, nogs_state->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1); 850 } 851 852 nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL, nogs_state->use_edgeflags); 853 nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u); 854 } 855 nir_pop_if(b, if_gs_accepted); 856 857 nir_store_var(b, es_accepted_var, es_survived, 0x1u); 858 nir_store_var(b, gs_accepted_var, nir_bcsel(b, fully_culled, nir_imm_false(b), nir_has_input_primitive_amd(b)), 0x1u); 859} 860 861static void 862analyze_shader_before_culling_walk(nir_ssa_def *ssa, 863 uint8_t flag, 864 lower_ngg_nogs_state *nogs_state) 865{ 866 nir_instr *instr = ssa->parent_instr; 867 uint8_t old_pass_flags = instr->pass_flags; 868 instr->pass_flags |= flag; 869 870 if (instr->pass_flags == old_pass_flags) 871 return; /* Already visited. */ 872 873 switch (instr->type) { 874 case nir_instr_type_intrinsic: { 875 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 876 877 /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */ 878 switch (intrin->intrinsic) { 879 case nir_intrinsic_load_input: { 880 nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin); 881 uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location; 882 if (instr->pass_flags & nggc_passflag_used_by_pos) 883 nogs_state->inputs_needed_by_pos |= in_mask; 884 else if (instr->pass_flags & nggc_passflag_used_by_other) 885 nogs_state->inputs_needed_by_others |= in_mask; 886 break; 887 } 888 default: 889 break; 890 } 891 892 break; 893 } 894 case nir_instr_type_alu: { 895 nir_alu_instr *alu = nir_instr_as_alu(instr); 896 unsigned num_srcs = nir_op_infos[alu->op].num_inputs; 897 898 for (unsigned i = 0; i < num_srcs; ++i) { 899 analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state); 900 } 901 902 break; 903 } 904 case nir_instr_type_phi: { 905 nir_phi_instr *phi = nir_instr_as_phi(instr); 906 nir_foreach_phi_src_safe(phi_src, phi) { 907 analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state); 908 } 909 910 break; 911 } 912 default: 913 break; 914 } 915} 916 917static void 918analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state) 919{ 920 nir_foreach_function(func, shader) { 921 nir_foreach_block(block, func->impl) { 922 nir_foreach_instr(instr, block) { 923 instr->pass_flags = 0; 924 925 if (instr->type != nir_instr_type_intrinsic) 926 continue; 927 928 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 929 if (intrin->intrinsic != nir_intrinsic_store_output) 930 continue; 931 932 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 933 nir_ssa_def *store_val = intrin->src[0].ssa; 934 uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other; 935 analyze_shader_before_culling_walk(store_val, flag, nogs_state); 936 } 937 } 938 } 939} 940 941/** 942 * Save the reusable SSA definitions to variables so that the 943 * bottom shader part can reuse them from the top part. 944 * 945 * 1. We create a new function temporary variable for reusables, 946 * and insert a store+load. 947 * 2. The shader is cloned (the top part is created), then the 948 * control flow is reinserted (for the bottom part.) 949 * 3. For reusables, we delete the variable stores from the 950 * bottom part. This will make them use the variables from 951 * the top part and DCE the redundant instructions. 952 */ 953static void 954save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state) 955{ 956 ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, 4, sizeof(saved_uniform)); 957 assert(vec_ok); 958 959 nir_block *block = nir_start_block(b->impl); 960 while (block) { 961 /* Process the instructions in the current block. */ 962 nir_foreach_instr_safe(instr, block) { 963 /* Find instructions whose SSA definitions are used by both 964 * the top and bottom parts of the shader (before and after culling). 965 * Only in this case, it makes sense for the bottom part 966 * to try to reuse these from the top part. 967 */ 968 if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both) 969 continue; 970 971 /* Determine if we can reuse the current SSA value. 972 * When vertex compaction is used, it is possible that the same shader invocation 973 * processes a different vertex in the top and bottom part of the shader. 974 * Therefore, we only reuse uniform values. 975 */ 976 nir_ssa_def *ssa = NULL; 977 switch (instr->type) { 978 case nir_instr_type_alu: { 979 nir_alu_instr *alu = nir_instr_as_alu(instr); 980 if (alu->dest.dest.ssa.divergent) 981 continue; 982 /* Ignore uniform floats because they regress VGPR usage too much */ 983 if (nir_op_infos[alu->op].output_type & nir_type_float) 984 continue; 985 ssa = &alu->dest.dest.ssa; 986 break; 987 } 988 case nir_instr_type_intrinsic: { 989 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 990 if (!nir_intrinsic_can_reorder(intrin) || 991 !nir_intrinsic_infos[intrin->intrinsic].has_dest || 992 intrin->dest.ssa.divergent) 993 continue; 994 ssa = &intrin->dest.ssa; 995 break; 996 } 997 case nir_instr_type_phi: { 998 nir_phi_instr *phi = nir_instr_as_phi(instr); 999 if (phi->dest.ssa.divergent) 1000 continue; 1001 ssa = &phi->dest.ssa; 1002 break; 1003 } 1004 default: 1005 continue; 1006 } 1007 1008 assert(ssa); 1009 1010 /* Determine a suitable type for the SSA value. */ 1011 enum glsl_base_type base_type = GLSL_TYPE_UINT; 1012 switch (ssa->bit_size) { 1013 case 8: base_type = GLSL_TYPE_UINT8; break; 1014 case 16: base_type = GLSL_TYPE_UINT16; break; 1015 case 32: base_type = GLSL_TYPE_UINT; break; 1016 case 64: base_type = GLSL_TYPE_UINT64; break; 1017 default: continue; 1018 } 1019 1020 const struct glsl_type *t = ssa->num_components == 1 1021 ? glsl_scalar_type(base_type) 1022 : glsl_vector_type(base_type, ssa->num_components); 1023 1024 saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms); 1025 assert(saved); 1026 1027 /* Create a new NIR variable where we store the reusable value. 1028 * Then, we reload the variable and replace the uses of the value 1029 * with the reloaded variable. 1030 */ 1031 saved->var = nir_local_variable_create(b->impl, t, NULL); 1032 saved->ssa = ssa; 1033 1034 b->cursor = instr->type == nir_instr_type_phi 1035 ? nir_after_instr_and_phis(instr) 1036 : nir_after_instr(instr); 1037 nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components)); 1038 nir_ssa_def *reloaded = nir_load_var(b, saved->var); 1039 nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr); 1040 } 1041 1042 /* Look at the next CF node. */ 1043 nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node); 1044 if (next_cf_node) { 1045 /* It makes no sense to try to reuse things from within loops. */ 1046 bool next_is_loop = next_cf_node->type == nir_cf_node_loop; 1047 1048 /* Don't reuse if we're in divergent control flow. 1049 * 1050 * Thanks to vertex repacking, the same shader invocation may process a different vertex 1051 * in the top and bottom part, and it's even possible that this different vertex was initially 1052 * processed in a different wave. So the two parts may take a different divergent code path. 1053 * Therefore, these variables in divergent control flow may stay undefined. 1054 * 1055 * Note that this problem doesn't exist if vertices are not repacked or if the 1056 * workgroup only has a single wave. 1057 */ 1058 bool next_is_divergent_if = 1059 next_cf_node->type == nir_cf_node_if && 1060 nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent; 1061 1062 if (next_is_loop || next_is_divergent_if) { 1063 block = nir_cf_node_cf_tree_next(next_cf_node); 1064 continue; 1065 } 1066 } 1067 1068 /* Go to the next block. */ 1069 block = nir_block_cf_tree_next(block); 1070 } 1071} 1072 1073/** 1074 * Reuses suitable variables from the top part of the shader, 1075 * by deleting their stores from the bottom part. 1076 */ 1077static void 1078apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state) 1079{ 1080 if (!u_vector_length(&nogs_state->saved_uniforms)) { 1081 u_vector_finish(&nogs_state->saved_uniforms); 1082 return; 1083 } 1084 1085 nir_foreach_block_reverse_safe(block, b->impl) { 1086 nir_foreach_instr_reverse_safe(instr, block) { 1087 if (instr->type != nir_instr_type_intrinsic) 1088 continue; 1089 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 1090 1091 /* When we found any of these intrinsics, it means 1092 * we reached the top part and we must stop. 1093 */ 1094 if (intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd) 1095 goto done; 1096 1097 if (intrin->intrinsic != nir_intrinsic_store_deref) 1098 continue; 1099 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); 1100 if (deref->deref_type != nir_deref_type_var) 1101 continue; 1102 1103 saved_uniform *saved; 1104 u_vector_foreach(saved, &nogs_state->saved_uniforms) { 1105 if (saved->var == deref->var) { 1106 nir_instr_remove(instr); 1107 } 1108 } 1109 } 1110 } 1111 1112 done: 1113 u_vector_finish(&nogs_state->saved_uniforms); 1114} 1115 1116static void 1117add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state) 1118{ 1119 bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID); 1120 bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID); 1121 1122 unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4; 1123 if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id) 1124 max_exported_args--; 1125 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id) 1126 max_exported_args--; 1127 1128 unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u; 1129 unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices; 1130 unsigned max_num_waves = nogs_state->max_num_waves; 1131 unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u); 1132 unsigned ngg_scratch_lds_bytes = ALIGN(max_num_waves, 4u); 1133 nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes; 1134 1135 nir_function_impl *impl = nir_shader_get_entrypoint(b->shader); 1136 1137 /* Create some helper variables. */ 1138 nir_variable *position_value_var = nogs_state->position_value_var; 1139 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var; 1140 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var; 1141 nir_variable *es_accepted_var = nogs_state->es_accepted_var; 1142 nir_variable *gs_vtxaddr_vars[3] = { 1143 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"), 1144 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"), 1145 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"), 1146 }; 1147 nir_variable *repacked_arg_vars[4] = { 1148 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"), 1149 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"), 1150 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"), 1151 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"), 1152 }; 1153 1154 /* Top part of the culling shader (aka. position shader part) 1155 * 1156 * We clone the full ES shader and emit it here, but we only really care 1157 * about its position output, so we delete every other output from this part. 1158 * The position output is stored into a temporary variable, and reloaded later. 1159 */ 1160 1161 b->cursor = nir_before_cf_list(&impl->body); 1162 1163 nir_ssa_def *es_thread = nir_has_input_vertex_amd(b); 1164 nir_if *if_es_thread = nir_push_if(b, es_thread); 1165 { 1166 /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output. 1167 * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that. 1168 */ 1169 nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu); 1170 1171 /* Now reinsert a clone of the shader code */ 1172 struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL); 1173 nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table); 1174 _mesa_hash_table_destroy(remap_table, NULL); 1175 b->cursor = nir_after_cf_list(&if_es_thread->then_list); 1176 1177 /* Remember the current thread's shader arguments */ 1178 if (b->shader->info.stage == MESA_SHADER_VERTEX) { 1179 nir_store_var(b, repacked_arg_vars[0], nir_load_vertex_id_zero_base(b), 0x1u); 1180 if (uses_instance_id) 1181 nir_store_var(b, repacked_arg_vars[1], nir_load_instance_id(b), 0x1u); 1182 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) { 1183 nir_ssa_def *tess_coord = nir_load_tess_coord(b); 1184 nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u); 1185 nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u); 1186 nir_store_var(b, repacked_arg_vars[2], nir_load_tess_rel_patch_id_amd(b), 0x1u); 1187 if (uses_tess_primitive_id) 1188 nir_store_var(b, repacked_arg_vars[3], nir_load_primitive_id(b), 0x1u); 1189 } else { 1190 unreachable("Should be VS or TES."); 1191 } 1192 } 1193 nir_pop_if(b, if_es_thread); 1194 1195 nir_store_var(b, es_accepted_var, es_thread, 0x1u); 1196 nir_store_var(b, gs_accepted_var, nir_has_input_primitive_amd(b), 0x1u); 1197 1198 /* Remove all non-position outputs, and put the position output into the variable. */ 1199 nir_metadata_preserve(impl, nir_metadata_none); 1200 remove_culling_shader_outputs(b->shader, nogs_state, position_value_var); 1201 b->cursor = nir_after_cf_list(&impl->body); 1202 1203 /* Run culling algorithms if culling is enabled. 1204 * 1205 * NGG culling can be enabled or disabled in runtime. 1206 * This is determined by a SGPR shader argument which is acccessed 1207 * by the following NIR intrinsic. 1208 */ 1209 1210 nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b)); 1211 { 1212 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); 1213 nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes); 1214 1215 /* ES invocations store their vertex data to LDS for GS threads to read. */ 1216 if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b)); 1217 { 1218 /* Store position components that are relevant to culling in LDS */ 1219 nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var); 1220 nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3); 1221 nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w); 1222 nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w); 1223 nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w); 1224 nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x); 1225 1226 /* Clear out the ES accepted flag in LDS */ 1227 nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted); 1228 } 1229 nir_pop_if(b, if_es_thread); 1230 1231 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 1232 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 1233 1234 nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u); 1235 nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u); 1236 1237 /* GS invocations load the vertex data and perform the culling. */ 1238 nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b)); 1239 { 1240 /* Load vertex indices from input VGPRs */ 1241 nir_ssa_def *vtx_idx[3] = {0}; 1242 for (unsigned vertex = 0; vertex < 3; ++vertex) 1243 vtx_idx[vertex] = nir_load_var(b, nogs_state->gs_vtx_indices_vars[vertex]); 1244 1245 nir_ssa_def *vtx_addr[3] = {0}; 1246 nir_ssa_def *pos[3][4] = {0}; 1247 1248 /* Load W positions of vertices first because the culling code will use these first */ 1249 for (unsigned vtx = 0; vtx < 3; ++vtx) { 1250 vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes); 1251 pos[vtx][3] = nir_load_shared(b, 1, 32, vtx_addr[vtx], .base = lds_es_pos_w); 1252 nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u); 1253 } 1254 1255 /* Load the X/W, Y/W positions of vertices */ 1256 for (unsigned vtx = 0; vtx < 3; ++vtx) { 1257 nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtx_addr[vtx], .base = lds_es_pos_x); 1258 pos[vtx][0] = nir_channel(b, xy, 0); 1259 pos[vtx][1] = nir_channel(b, xy, 1); 1260 } 1261 1262 /* See if the current primitive is accepted */ 1263 nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos); 1264 nir_store_var(b, gs_accepted_var, accepted, 0x1u); 1265 1266 nir_if *if_gs_accepted = nir_push_if(b, accepted); 1267 { 1268 /* Store the accepted state to LDS for ES threads */ 1269 for (unsigned vtx = 0; vtx < 3; ++vtx) 1270 nir_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u); 1271 } 1272 nir_pop_if(b, if_gs_accepted); 1273 } 1274 nir_pop_if(b, if_gs_thread); 1275 1276 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 1277 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 1278 1279 nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u); 1280 1281 /* ES invocations load their accepted flag from LDS. */ 1282 if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b)); 1283 { 1284 nir_ssa_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u); 1285 nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8)); 1286 nir_store_var(b, es_accepted_var, accepted_bool, 0x1u); 1287 } 1288 nir_pop_if(b, if_es_thread); 1289 1290 nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var); 1291 1292 /* Repack the vertices that survived the culling. */ 1293 wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr, 1294 nogs_state->max_num_waves, nogs_state->wave_size); 1295 nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations; 1296 nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index; 1297 1298 /* If all vertices are culled, set primitive count to 0 as well. */ 1299 nir_ssa_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b); 1300 nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u); 1301 num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims); 1302 1303 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0))); 1304 { 1305 /* Tell the final vertex and primitive count to the HW. */ 1306 nir_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims); 1307 } 1308 nir_pop_if(b, if_wave_0); 1309 1310 /* Vertex compaction. */ 1311 compact_vertices_after_culling(b, nogs_state, 1312 repacked_arg_vars, gs_vtxaddr_vars, 1313 invocation_index, es_vertex_lds_addr, 1314 es_exporter_tid, num_live_vertices_in_workgroup, fully_culled, 1315 ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args); 1316 } 1317 nir_push_else(b, if_cull_en); 1318 { 1319 /* When culling is disabled, we do the same as we would without culling. */ 1320 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0))); 1321 { 1322 nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b); 1323 nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b); 1324 nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt); 1325 } 1326 nir_pop_if(b, if_wave_0); 1327 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u); 1328 } 1329 nir_pop_if(b, if_cull_en); 1330 1331 /* Update shader arguments. 1332 * 1333 * The registers which hold information about the subgroup's 1334 * vertices and primitives are updated here, so the rest of the shader 1335 * doesn't need to worry about the culling. 1336 * 1337 * These "overwrite" intrinsics must be at top level control flow, 1338 * otherwise they can mess up the backend (eg. ACO's SSA). 1339 * 1340 * TODO: 1341 * A cleaner solution would be to simply replace all usages of these args 1342 * with the load of the variables. 1343 * However, this wouldn't work right now because the backend uses the arguments 1344 * for purposes not expressed in NIR, eg. VS input loads, etc. 1345 * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd. 1346 */ 1347 1348 if (b->shader->info.stage == MESA_SHADER_VERTEX) 1349 nogs_state->overwrite_args = 1350 nir_overwrite_vs_arguments_amd(b, 1351 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1])); 1352 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) 1353 nogs_state->overwrite_args = 1354 nir_overwrite_tes_arguments_amd(b, 1355 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]), 1356 nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3])); 1357 else 1358 unreachable("Should be VS or TES."); 1359} 1360 1361void 1362ac_nir_lower_ngg_nogs(nir_shader *shader, 1363 enum radeon_family family, 1364 unsigned max_num_es_vertices, 1365 unsigned num_vertices_per_primitives, 1366 unsigned max_workgroup_size, 1367 unsigned wave_size, 1368 bool can_cull, 1369 bool early_prim_export, 1370 bool passthrough, 1371 bool export_prim_id, 1372 bool provoking_vtx_last, 1373 bool use_edgeflags, 1374 bool has_prim_query, 1375 uint32_t instance_rate_inputs) 1376{ 1377 nir_function_impl *impl = nir_shader_get_entrypoint(shader); 1378 assert(impl); 1379 assert(max_num_es_vertices && max_workgroup_size && wave_size); 1380 assert(!(can_cull && passthrough)); 1381 1382 nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value"); 1383 nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg"); 1384 nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL; 1385 nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL; 1386 1387 lower_ngg_nogs_state state = { 1388 .passthrough = passthrough, 1389 .export_prim_id = export_prim_id, 1390 .early_prim_export = early_prim_export, 1391 .use_edgeflags = use_edgeflags, 1392 .has_prim_query = has_prim_query, 1393 .can_cull = can_cull, 1394 .num_vertices_per_primitives = num_vertices_per_primitives, 1395 .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0, 1396 .position_value_var = position_value_var, 1397 .prim_exp_arg_var = prim_exp_arg_var, 1398 .es_accepted_var = es_accepted_var, 1399 .gs_accepted_var = gs_accepted_var, 1400 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size), 1401 .max_es_num_vertices = max_num_es_vertices, 1402 .wave_size = wave_size, 1403 .instance_rate_inputs = instance_rate_inputs, 1404 }; 1405 1406 const bool need_prim_id_store_shared = 1407 export_prim_id && shader->info.stage == MESA_SHADER_VERTEX; 1408 1409 if (export_prim_id) { 1410 nir_variable *prim_id_var = nir_variable_create(shader, nir_var_shader_out, glsl_uint_type(), "ngg_prim_id"); 1411 prim_id_var->data.location = VARYING_SLOT_PRIMITIVE_ID; 1412 prim_id_var->data.driver_location = VARYING_SLOT_PRIMITIVE_ID; 1413 prim_id_var->data.interpolation = INTERP_MODE_NONE; 1414 shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID; 1415 } 1416 1417 nir_builder builder; 1418 nir_builder *b = &builder; /* This is to avoid the & */ 1419 nir_builder_init(b, impl); 1420 1421 if (can_cull) { 1422 /* We need divergence info for culling shaders. */ 1423 nir_divergence_analysis(shader); 1424 analyze_shader_before_culling(shader, &state); 1425 save_reusable_variables(b, &state); 1426 } 1427 1428 nir_cf_list extracted; 1429 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); 1430 b->cursor = nir_before_cf_list(&impl->body); 1431 1432 ngg_nogs_init_vertex_indices_vars(b, impl, &state); 1433 1434 if (!can_cull) { 1435 /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */ 1436 if (!(passthrough && family >= CHIP_NAVI23)) { 1437 /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */ 1438 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0))); 1439 { 1440 nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b); 1441 nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b); 1442 nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt); 1443 } 1444 nir_pop_if(b, if_wave_0); 1445 } 1446 1447 /* Take care of early primitive export, otherwise just pack the primitive export argument */ 1448 if (state.early_prim_export) 1449 emit_ngg_nogs_prim_export(b, &state, NULL); 1450 else 1451 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u); 1452 } else { 1453 add_deferred_attribute_culling(b, &extracted, &state); 1454 b->cursor = nir_after_cf_list(&impl->body); 1455 1456 if (state.early_prim_export) 1457 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var)); 1458 } 1459 1460 if (need_prim_id_store_shared) { 1461 /* We need LDS space when VS needs to export the primitive ID. */ 1462 state.total_lds_bytes = MAX2(state.total_lds_bytes, max_num_es_vertices * 4u); 1463 1464 /* The LDS space aliases with what is used by culling, so we need a barrier. */ 1465 if (can_cull) { 1466 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 1467 .memory_scope = NIR_SCOPE_WORKGROUP, 1468 .memory_semantics = NIR_MEMORY_ACQ_REL, 1469 .memory_modes = nir_var_mem_shared); 1470 } 1471 1472 emit_ngg_nogs_prim_id_store_shared(b, &state); 1473 1474 /* Wait for GS threads to store primitive ID in LDS. */ 1475 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP, 1476 .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared); 1477 } 1478 1479 nir_intrinsic_instr *export_vertex_instr; 1480 nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b); 1481 1482 nir_if *if_es_thread = nir_push_if(b, es_thread); 1483 { 1484 /* Run the actual shader */ 1485 nir_cf_reinsert(&extracted, b->cursor); 1486 b->cursor = nir_after_cf_list(&if_es_thread->then_list); 1487 1488 if (state.export_prim_id) 1489 emit_store_ngg_nogs_es_primitive_id(b); 1490 1491 /* Export all vertex attributes (including the primitive ID) */ 1492 export_vertex_instr = nir_export_vertex_amd(b); 1493 } 1494 nir_pop_if(b, if_es_thread); 1495 1496 /* Take care of late primitive export */ 1497 if (!state.early_prim_export) { 1498 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var)); 1499 } 1500 1501 if (can_cull) { 1502 /* Replace uniforms. */ 1503 apply_reusable_variables(b, &state); 1504 1505 /* Remove the redundant position output. */ 1506 remove_extra_pos_outputs(shader, &state); 1507 1508 /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3, 1509 * it seems that it's best to put the position export always at the end, and 1510 * then let ACO schedule it up (slightly) only when early prim export is used. 1511 */ 1512 b->cursor = nir_before_instr(&export_vertex_instr->instr); 1513 1514 nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var); 1515 nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 }; 1516 nir_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem); 1517 } 1518 1519 nir_metadata_preserve(impl, nir_metadata_none); 1520 nir_validate_shader(shader, "after emitting NGG VS/TES"); 1521 1522 /* Cleanup */ 1523 nir_opt_dead_write_vars(shader); 1524 nir_lower_vars_to_ssa(shader); 1525 nir_remove_dead_variables(shader, nir_var_function_temp, NULL); 1526 nir_lower_alu_to_scalar(shader, NULL, NULL); 1527 nir_lower_phis_to_scalar(shader, true); 1528 1529 if (can_cull) { 1530 /* It's beneficial to redo these opts after splitting the shader. */ 1531 nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies); 1532 nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef); 1533 } 1534 1535 bool progress; 1536 do { 1537 progress = false; 1538 NIR_PASS(progress, shader, nir_opt_undef); 1539 NIR_PASS(progress, shader, nir_opt_dce); 1540 NIR_PASS(progress, shader, nir_opt_dead_cf); 1541 1542 if (can_cull) 1543 progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state); 1544 } while (progress); 1545 1546 shader->info.shared_size = state.total_lds_bytes; 1547} 1548 1549/** 1550 * Return the address of the LDS storage reserved for the N'th vertex, 1551 * where N is in emit order, meaning: 1552 * - during the finale, N is the invocation_index (within the workgroup) 1553 * - during vertex emit, i.e. while the API GS shader invocation is running, 1554 * N = invocation_index * gs_max_out_vertices + emit_idx 1555 * where emit_idx is the vertex index in the current API GS invocation. 1556 * 1557 * Goals of the LDS memory layout: 1558 * 1. Eliminate bank conflicts on write for geometry shaders that have all emits 1559 * in uniform control flow 1560 * 2. Eliminate bank conflicts on read for export if, additionally, there is no 1561 * culling 1562 * 3. Agnostic to the number of waves (since we don't know it before compiling) 1563 * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.) 1564 * 5. Avoid wasting memory. 1565 * 1566 * We use an AoS layout due to point 4 (this also helps point 3). In an AoS 1567 * layout, elimination of bank conflicts requires that each vertex occupy an 1568 * odd number of dwords. We use the additional dword to store the output stream 1569 * index as well as a flag to indicate whether this vertex ends a primitive 1570 * for rasterization. 1571 * 1572 * Swizzling is required to satisfy points 1 and 2 simultaneously. 1573 * 1574 * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx). 1575 * Indices are swizzled in groups of 32, which ensures point 1 without 1576 * disturbing point 2. 1577 * 1578 * \return an LDS pointer to type {[N x i32], [4 x i8]} 1579 */ 1580static nir_ssa_def * 1581ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s) 1582{ 1583 unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1; 1584 1585 /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */ 1586 if (write_stride_2exp) { 1587 nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5); 1588 nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u); 1589 out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle); 1590 } 1591 1592 nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex); 1593 return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx); 1594} 1595 1596static nir_ssa_def * 1597ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s) 1598{ 1599 nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b); 1600 nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out); 1601 nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx); 1602 1603 return ngg_gs_out_vertex_addr(b, out_vtx_idx, s); 1604} 1605 1606static void 1607ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s) 1608{ 1609 nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8); 1610 nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u); 1611 1612 nir_loop *loop = nir_push_loop(b); 1613 { 1614 nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var); 1615 nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out))); 1616 { 1617 nir_jump(b, nir_jump_break); 1618 } 1619 nir_push_else(b, if_break); 1620 { 1621 nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s); 1622 nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream); 1623 nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u); 1624 } 1625 nir_pop_if(b, if_break); 1626 } 1627 nir_pop_loop(b, loop); 1628} 1629 1630static void 1631ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s) 1632{ 1633 nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b)); 1634 nir_ssa_def *num_prims_in_wave = NULL; 1635 1636 /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives. 1637 * GS emits points, line strips or triangle strips. 1638 * Real primitives are points, lines or triangles. 1639 */ 1640 if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) { 1641 unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]); 1642 unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]); 1643 unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u); 1644 nir_ssa_def *num_threads = nir_bit_count(b, nir_ballot(b, 1, s->wave_size, nir_imm_bool(b, true))); 1645 num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt); 1646 } else { 1647 nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa; 1648 nir_ssa_def *prm_cnt = intrin->src[1].ssa; 1649 if (s->num_vertices_per_primitive > 1) 1650 prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt); 1651 num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd); 1652 } 1653 1654 /* Store the query result to GDS using an atomic add. */ 1655 nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1)); 1656 nir_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100)); 1657 nir_pop_if(b, if_first_lane); 1658 1659 nir_pop_if(b, if_shader_query); 1660} 1661 1662static bool 1663lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s) 1664{ 1665 assert(nir_src_is_const(intrin->src[1])); 1666 b->cursor = nir_before_instr(&intrin->instr); 1667 1668 unsigned writemask = nir_intrinsic_write_mask(intrin); 1669 unsigned base = nir_intrinsic_base(intrin); 1670 unsigned component_offset = nir_intrinsic_component(intrin); 1671 unsigned base_offset = nir_src_as_uint(intrin->src[1]); 1672 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 1673 1674 assert((base + base_offset) < VARYING_SLOT_MAX); 1675 1676 nir_ssa_def *store_val = intrin->src[0].ssa; 1677 1678 for (unsigned comp = 0; comp < 4; ++comp) { 1679 if (!(writemask & (1 << comp))) 1680 continue; 1681 unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3; 1682 if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) 1683 continue; 1684 1685 /* Small bitsize components consume the same amount of space as 32-bit components, 1686 * but 64-bit ones consume twice as many. (Vulkan spec 15.1.5) 1687 */ 1688 unsigned num_consumed_components = DIV_ROUND_UP(store_val->bit_size, 32); 1689 nir_ssa_def *element = nir_channel(b, store_val, comp); 1690 if (num_consumed_components > 1) 1691 element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32); 1692 1693 /* Save output usage info. */ 1694 gs_output_info *info = &s->output_info[io_sem.location]; 1695 /* The same output should always belong to the same stream. */ 1696 assert(!info->components_mask || info->stream == stream); 1697 info->stream = stream; 1698 info->components_mask |= BITFIELD_BIT(component_offset + comp * num_consumed_components); 1699 1700 for (unsigned c = 0; c < num_consumed_components; ++c) { 1701 unsigned component_index = (comp * num_consumed_components) + c + component_offset; 1702 unsigned base_index = base + base_offset + component_index / 4; 1703 component_index %= 4; 1704 1705 /* Store the current component element */ 1706 nir_ssa_def *component_element = element; 1707 if (num_consumed_components > 1) 1708 component_element = nir_channel(b, component_element, c); 1709 if (component_element->bit_size != 32) 1710 component_element = nir_u2u32(b, component_element); 1711 1712 nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u); 1713 } 1714 } 1715 1716 nir_instr_remove(&intrin->instr); 1717 return true; 1718} 1719 1720static bool 1721lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s) 1722{ 1723 b->cursor = nir_before_instr(&intrin->instr); 1724 1725 unsigned stream = nir_intrinsic_stream_id(intrin); 1726 if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) { 1727 nir_instr_remove(&intrin->instr); 1728 return true; 1729 } 1730 1731 nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa; 1732 nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa; 1733 nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s); 1734 1735 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) { 1736 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot))); 1737 gs_output_info *info = &s->output_info[slot]; 1738 if (info->stream != stream || !info->components_mask) 1739 continue; 1740 1741 unsigned mask = info->components_mask; 1742 while (mask) { 1743 int start, count; 1744 u_bit_scan_consecutive_range(&mask, &start, &count); 1745 nir_ssa_def *values[4] = {0}; 1746 for (int c = start; c < start + count; ++c) { 1747 /* Load output from variable. */ 1748 values[c - start] = nir_load_var(b, s->output_vars[slot][c]); 1749 /* Clear the variable (it is undefined after emit_vertex) */ 1750 nir_store_var(b, s->output_vars[slot][c], nir_ssa_undef(b, 1, 32), 0x1); 1751 } 1752 1753 nir_ssa_def *store_val = nir_vec(b, values, (unsigned)count); 1754 nir_store_shared(b, store_val, gs_emit_vtx_addr, 1755 .base = packed_location * 16 + start * 4, 1756 .align_mul = 4); 1757 } 1758 } 1759 1760 /* Calculate and store per-vertex primitive flags based on vertex counts: 1761 * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip) 1762 * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0) 1763 * - bit 2: always 1 (so that we can use it for determining vertex liveness) 1764 */ 1765 1766 nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1)); 1767 nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u)); 1768 1769 if (s->num_vertices_per_primitive == 3) { 1770 nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1); 1771 prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1))); 1772 } 1773 1774 nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u); 1775 nir_instr_remove(&intrin->instr); 1776 return true; 1777} 1778 1779static bool 1780lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s) 1781{ 1782 b->cursor = nir_before_instr(&intrin->instr); 1783 1784 /* These are not needed, we can simply remove them */ 1785 nir_instr_remove(&intrin->instr); 1786 return true; 1787} 1788 1789static bool 1790lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s) 1791{ 1792 b->cursor = nir_before_instr(&intrin->instr); 1793 1794 unsigned stream = nir_intrinsic_stream_id(intrin); 1795 if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) { 1796 nir_instr_remove(&intrin->instr); 1797 return true; 1798 } 1799 1800 s->found_out_vtxcnt[stream] = true; 1801 1802 /* Clear the primitive flags of non-emitted vertices */ 1803 if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out) 1804 ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s); 1805 1806 ngg_gs_shader_query(b, intrin, s); 1807 nir_instr_remove(&intrin->instr); 1808 return true; 1809} 1810 1811static bool 1812lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state) 1813{ 1814 lower_ngg_gs_state *s = (lower_ngg_gs_state *) state; 1815 1816 if (instr->type != nir_instr_type_intrinsic) 1817 return false; 1818 1819 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 1820 1821 if (intrin->intrinsic == nir_intrinsic_store_output) 1822 return lower_ngg_gs_store_output(b, intrin, s); 1823 else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter) 1824 return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s); 1825 else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter) 1826 return lower_ngg_gs_end_primitive_with_counter(b, intrin, s); 1827 else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count) 1828 return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s); 1829 1830 return false; 1831} 1832 1833static void 1834lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s) 1835{ 1836 nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s); 1837} 1838 1839static void 1840ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg, 1841 nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0, 1842 lower_ngg_gs_state *s) 1843{ 1844 nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims)); 1845 1846 /* Only bit 0 matters here - set it to 1 when the primitive should be null */ 1847 nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u)); 1848 1849 nir_ssa_def *vtx_indices[3] = {0}; 1850 vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg; 1851 if (s->num_vertices_per_primitive >= 2) 1852 vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1)); 1853 if (s->num_vertices_per_primitive == 3) 1854 vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2)); 1855 1856 if (s->num_vertices_per_primitive == 3) { 1857 /* API GS outputs triangle strips, but NGG HW understands triangles. 1858 * We already know the triangles due to how we set the primitive flags, but we need to 1859 * make sure the vertex order is so that the front/back is correct, and the provoking vertex is kept. 1860 */ 1861 1862 nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1)); 1863 if (!s->provoking_vertex_last) { 1864 vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd); 1865 vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd); 1866 } else { 1867 vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd); 1868 vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd); 1869 } 1870 } 1871 1872 nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false); 1873 nir_export_primitive_amd(b, arg); 1874 nir_pop_if(b, if_prim_export_thread); 1875} 1876 1877static void 1878ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg, 1879 nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s) 1880{ 1881 nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx)); 1882 nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr; 1883 1884 if (!s->output_compile_time_known) { 1885 /* Vertex compaction. 1886 * The current thread will export a vertex that was live in another invocation. 1887 * Load the index of the vertex that the current thread will have to export. 1888 */ 1889 nir_ssa_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1); 1890 exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s); 1891 } 1892 1893 /* Remember proper bit sizes of output variables. */ 1894 uint8_t out_bitsizes[VARYING_SLOT_MAX]; 1895 memset(out_bitsizes, 32, VARYING_SLOT_MAX); 1896 nir_foreach_shader_out_variable(var, b->shader) { 1897 /* Check 8/16-bit. All others should be lowered to 32-bit already. */ 1898 unsigned bit_size = glsl_base_type_bit_size(glsl_get_base_type(glsl_without_array(var->type))); 1899 if (bit_size == 8 || bit_size == 16) 1900 out_bitsizes[var->data.location] = bit_size; 1901 } 1902 1903 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) { 1904 if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot))) 1905 continue; 1906 1907 gs_output_info *info = &s->output_info[slot]; 1908 if (!info->components_mask || info->stream != 0) 1909 continue; 1910 1911 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot))); 1912 nir_io_semantics io_sem = { .location = slot, .num_slots = 1 }; 1913 1914 unsigned mask = info->components_mask; 1915 while (mask) { 1916 int start, count; 1917 u_bit_scan_consecutive_range(&mask, &start, &count); 1918 nir_ssa_def *load = 1919 nir_load_shared(b, count, 32, exported_out_vtx_lds_addr, 1920 .base = packed_location * 16 + start * 4, 1921 .align_mul = 4); 1922 1923 /* Convert to the expected bit size of the output variable. */ 1924 if (out_bitsizes[slot] != 32) 1925 load = nir_u2u(b, load, out_bitsizes[slot]); 1926 1927 nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .io_semantics = io_sem, 1928 .component = start, .write_mask = BITFIELD_MASK(count)); 1929 } 1930 } 1931 1932 nir_export_vertex_amd(b); 1933 nir_pop_if(b, if_vtx_export_thread); 1934} 1935 1936static void 1937ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg, 1938 nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s) 1939{ 1940 assert(vertex_live->bit_size == 1); 1941 nir_if *if_vertex_live = nir_push_if(b, vertex_live); 1942 { 1943 /* Setup the vertex compaction. 1944 * Save the current thread's id for the thread which will export the current vertex. 1945 * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this. 1946 */ 1947 1948 nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s); 1949 nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg); 1950 nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1); 1951 } 1952 nir_pop_if(b, if_vertex_live); 1953} 1954 1955static nir_ssa_def * 1956ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr, 1957 nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s) 1958{ 1959 nir_ssa_def *zero = nir_imm_int(b, 0); 1960 1961 nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx)); 1962 nir_ssa_def *primflag_0 = nir_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u); 1963 primflag_0 = nir_u2u32(b, primflag_0); 1964 nir_pop_if(b, if_outvtx_thread); 1965 1966 return nir_if_phi(b, primflag_0, zero); 1967} 1968 1969static void 1970ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s) 1971{ 1972 nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b); 1973 nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b); 1974 nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */ 1975 nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s); 1976 1977 if (s->output_compile_time_known) { 1978 /* When the output is compile-time known, the GS writes all possible vertices and primitives it can. 1979 * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs. 1980 */ 1981 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32))); 1982 nir_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt); 1983 nir_pop_if(b, if_wave_0); 1984 } 1985 1986 /* Workgroup barrier: wait for all GS threads to finish */ 1987 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 1988 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 1989 1990 nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s); 1991 1992 if (s->output_compile_time_known) { 1993 ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s); 1994 ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s); 1995 return; 1996 } 1997 1998 /* When the output vertex count is not known at compile time: 1999 * There may be gaps between invocations that have live vertices, but NGG hardware 2000 * requires that the invocations that export vertices are packed (ie. compact). 2001 * To ensure this, we need to repack invocations that have a live vertex. 2002 */ 2003 nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size)); 2004 wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size); 2005 2006 nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations; 2007 nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index; 2008 2009 /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */ 2010 nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0)); 2011 max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0)); 2012 2013 /* Allocate export space. We currently don't compact primitives, just use the maximum number. */ 2014 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32))); 2015 nir_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt); 2016 nir_pop_if(b, if_wave_0); 2017 2018 /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */ 2019 ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s); 2020 2021 /* Workgroup barrier: wait for all LDS stores to finish. */ 2022 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 2023 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 2024 2025 ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s); 2026 ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s); 2027} 2028 2029void 2030ac_nir_lower_ngg_gs(nir_shader *shader, 2031 unsigned wave_size, 2032 unsigned max_workgroup_size, 2033 unsigned esgs_ring_lds_bytes, 2034 unsigned gs_out_vtx_bytes, 2035 unsigned gs_total_out_vtx_bytes, 2036 bool provoking_vertex_last) 2037{ 2038 nir_function_impl *impl = nir_shader_get_entrypoint(shader); 2039 assert(impl); 2040 2041 lower_ngg_gs_state state = { 2042 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size), 2043 .wave_size = wave_size, 2044 .lds_addr_gs_out_vtx = esgs_ring_lds_bytes, 2045 .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */), 2046 .lds_offs_primflags = gs_out_vtx_bytes, 2047 .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u, 2048 .provoking_vertex_last = provoking_vertex_last, 2049 }; 2050 2051 unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u; 2052 unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes; 2053 shader->info.shared_size = total_lds_bytes; 2054 2055 nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u); 2056 state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out && 2057 state.const_out_prmcnt[0] != -1; 2058 2059 if (!state.output_compile_time_known) 2060 state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx"); 2061 2062 if (shader->info.gs.output_primitive == SHADER_PRIM_POINTS) 2063 state.num_vertices_per_primitive = 1; 2064 else if (shader->info.gs.output_primitive == SHADER_PRIM_LINE_STRIP) 2065 state.num_vertices_per_primitive = 2; 2066 else if (shader->info.gs.output_primitive == SHADER_PRIM_TRIANGLE_STRIP) 2067 state.num_vertices_per_primitive = 3; 2068 else 2069 unreachable("Invalid GS output primitive."); 2070 2071 /* Extract the full control flow. It is going to be wrapped in an if statement. */ 2072 nir_cf_list extracted; 2073 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); 2074 2075 nir_builder builder; 2076 nir_builder *b = &builder; /* This is to avoid the & */ 2077 nir_builder_init(b, impl); 2078 b->cursor = nir_before_cf_list(&impl->body); 2079 2080 /* Workgroup barrier: wait for ES threads */ 2081 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 2082 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); 2083 2084 /* Wrap the GS control flow. */ 2085 nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b)); 2086 2087 /* Create and initialize output variables */ 2088 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) { 2089 for (unsigned comp = 0; comp < 4; ++comp) { 2090 state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output"); 2091 } 2092 } 2093 2094 nir_cf_reinsert(&extracted, b->cursor); 2095 b->cursor = nir_after_cf_list(&if_gs_thread->then_list); 2096 nir_pop_if(b, if_gs_thread); 2097 2098 /* Lower the GS intrinsics */ 2099 lower_ngg_gs_intrinsics(shader, &state); 2100 b->cursor = nir_after_cf_list(&impl->body); 2101 2102 if (!state.found_out_vtxcnt[0]) { 2103 fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU."); 2104 abort(); 2105 } 2106 2107 /* Emit the finale sequence */ 2108 ngg_gs_finale(b, &state); 2109 nir_validate_shader(shader, "after emitting NGG GS"); 2110 2111 /* Cleanup */ 2112 nir_lower_vars_to_ssa(shader); 2113 nir_remove_dead_variables(shader, nir_var_function_temp, NULL); 2114 nir_metadata_preserve(impl, nir_metadata_none); 2115} 2116 2117static void 2118ms_store_prim_indices(nir_builder *b, 2119 nir_ssa_def *val, 2120 nir_ssa_def *offset_src, 2121 lower_ngg_ms_state *s) 2122{ 2123 assert(val->num_components <= 3); 2124 2125 if (!offset_src) 2126 offset_src = nir_imm_int(b, 0); 2127 2128 nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr); 2129} 2130 2131static nir_ssa_def * 2132ms_load_prim_indices(nir_builder *b, 2133 nir_ssa_def *offset_src, 2134 lower_ngg_ms_state *s) 2135{ 2136 if (!offset_src) 2137 offset_src = nir_imm_int(b, 0); 2138 2139 return nir_load_shared(b, 1, 8, offset_src, .base = s->layout.lds.indices_addr); 2140} 2141 2142static void 2143ms_store_num_prims(nir_builder *b, 2144 nir_ssa_def *store_val, 2145 lower_ngg_ms_state *s) 2146{ 2147 nir_ssa_def *addr = nir_imm_int(b, 0); 2148 nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); 2149} 2150 2151static nir_ssa_def * 2152ms_load_num_prims(nir_builder *b, 2153 lower_ngg_ms_state *s) 2154{ 2155 nir_ssa_def *addr = nir_imm_int(b, 0); 2156 return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); 2157} 2158 2159static nir_ssa_def * 2160lower_ms_store_output(nir_builder *b, 2161 nir_intrinsic_instr *intrin, 2162 lower_ngg_ms_state *s) 2163{ 2164 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 2165 nir_ssa_def *store_val = intrin->src[0].ssa; 2166 2167 /* Component makes no sense here. */ 2168 assert(nir_intrinsic_component(intrin) == 0); 2169 2170 if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) { 2171 /* Total number of primitives output by the mesh shader workgroup. 2172 * This can be read and written by any invocation any number of times. 2173 */ 2174 2175 /* Base, offset and component make no sense here. */ 2176 assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0); 2177 2178 ms_store_num_prims(b, store_val, s); 2179 } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) { 2180 /* Contrary to the name, these are not primitive indices, but 2181 * vertex indices for each vertex of the output primitives. 2182 * The Mesh NV API has these stored in a flat array. 2183 */ 2184 2185 nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa; 2186 ms_store_prim_indices(b, store_val, offset_src, s); 2187 } else { 2188 unreachable("Invalid mesh shader output"); 2189 } 2190 2191 return NIR_LOWER_INSTR_PROGRESS_REPLACE; 2192} 2193 2194static nir_ssa_def * 2195lower_ms_load_output(nir_builder *b, 2196 nir_intrinsic_instr *intrin, 2197 lower_ngg_ms_state *s) 2198{ 2199 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 2200 2201 /* Component makes no sense here. */ 2202 assert(nir_intrinsic_component(intrin) == 0); 2203 2204 if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) { 2205 /* Base, offset and component make no sense here. */ 2206 assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0); 2207 2208 return ms_load_num_prims(b, s); 2209 } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) { 2210 nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa; 2211 nir_ssa_def *index = ms_load_prim_indices(b, offset_src, s); 2212 return nir_u2u(b, index, intrin->dest.ssa.bit_size); 2213 } 2214 2215 unreachable("Invalid mesh shader output"); 2216} 2217 2218static nir_ssa_def * 2219ms_arrayed_output_base_addr(nir_builder *b, 2220 nir_ssa_def *arr_index, 2221 unsigned driver_location, 2222 unsigned num_arrayed_outputs) 2223{ 2224 /* Address offset of the array item (vertex or primitive). */ 2225 unsigned arr_index_stride = num_arrayed_outputs * 16u; 2226 nir_ssa_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride); 2227 2228 /* IO address offset within the vertex or primitive data. */ 2229 unsigned io_offset = driver_location * 16u; 2230 nir_ssa_def *io_off = nir_imm_int(b, io_offset); 2231 2232 return nir_iadd_nuw(b, arr_index_off, io_off); 2233} 2234 2235static void 2236update_ms_output_info_slot(lower_ngg_ms_state *s, 2237 unsigned slot, unsigned base_off, 2238 uint32_t components_mask) 2239{ 2240 while (components_mask) { 2241 s->output_info[slot + base_off].components_mask |= components_mask & 0xF; 2242 2243 components_mask >>= 4; 2244 base_off++; 2245 } 2246} 2247 2248static void 2249update_ms_output_info(nir_intrinsic_instr *intrin, 2250 const ms_out_part *out, 2251 lower_ngg_ms_state *s) 2252{ 2253 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); 2254 nir_src *base_offset_src = nir_get_io_offset_src(intrin); 2255 uint32_t write_mask = nir_intrinsic_write_mask(intrin); 2256 unsigned component_offset = nir_intrinsic_component(intrin); 2257 2258 nir_ssa_def *store_val = intrin->src[0].ssa; 2259 write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32)); 2260 uint32_t components_mask = write_mask << component_offset; 2261 2262 if (nir_src_is_const(*base_offset_src)) { 2263 /* Simply mark the components of the current slot as used. */ 2264 unsigned base_off = nir_src_as_uint(*base_offset_src); 2265 update_ms_output_info_slot(s, io_sem.location, base_off, components_mask); 2266 } else { 2267 /* Indirect offset: mark the components of all slots as used. */ 2268 for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off) 2269 update_ms_output_info_slot(s, io_sem.location, base_off, components_mask); 2270 } 2271} 2272 2273static nir_ssa_def * 2274regroup_store_val(nir_builder *b, nir_ssa_def *store_val) 2275{ 2276 /* Vulkan spec 15.1.4-15.1.5: 2277 * 2278 * The shader interface consists of output slots with 4x 32-bit components. 2279 * Small bitsize components consume the same space as 32-bit components, 2280 * but 64-bit ones consume twice as much. 2281 * 2282 * The same output slot may consist of components of different bit sizes. 2283 * Therefore for simplicity we don't store small bitsize components 2284 * contiguously, but pad them instead. In practice, they are converted to 2285 * 32-bit and then stored contiguously. 2286 */ 2287 2288 if (store_val->bit_size < 32) { 2289 assert(store_val->num_components <= 4); 2290 nir_ssa_def *comps[4] = {0}; 2291 for (unsigned c = 0; c < store_val->num_components; ++c) 2292 comps[c] = nir_u2u32(b, nir_channel(b, store_val, c)); 2293 return nir_vec(b, comps, store_val->num_components); 2294 } 2295 2296 return store_val; 2297} 2298 2299static nir_ssa_def * 2300regroup_load_val(nir_builder *b, nir_ssa_def *load, unsigned dest_bit_size) 2301{ 2302 if (dest_bit_size == load->bit_size) 2303 return load; 2304 2305 /* Small bitsize components are not stored contiguously, take care of that here. */ 2306 unsigned num_components = load->num_components; 2307 assert(num_components <= 4); 2308 nir_ssa_def *components[4] = {0}; 2309 for (unsigned i = 0; i < num_components; ++i) 2310 components[i] = nir_u2u(b, nir_channel(b, load, i), dest_bit_size); 2311 2312 return nir_vec(b, components, num_components); 2313} 2314 2315static const ms_out_part * 2316ms_get_out_layout_part(unsigned location, 2317 shader_info *info, 2318 ms_out_mode *out_mode, 2319 lower_ngg_ms_state *s) 2320{ 2321 uint64_t mask = BITFIELD64_BIT(location); 2322 2323 if (info->per_primitive_outputs & mask) { 2324 if (mask & s->layout.lds.prm_attr.mask) { 2325 *out_mode = ms_out_mode_lds; 2326 return &s->layout.lds.prm_attr; 2327 } else if (mask & s->layout.vram.prm_attr.mask) { 2328 *out_mode = ms_out_mode_vram; 2329 return &s->layout.vram.prm_attr; 2330 } else if (mask & s->layout.var.prm_attr.mask) { 2331 *out_mode = ms_out_mode_var; 2332 return &s->layout.var.prm_attr; 2333 } 2334 } else { 2335 if (mask & s->layout.lds.vtx_attr.mask) { 2336 *out_mode = ms_out_mode_lds; 2337 return &s->layout.lds.vtx_attr; 2338 } else if (mask & s->layout.vram.vtx_attr.mask) { 2339 *out_mode = ms_out_mode_vram; 2340 return &s->layout.vram.vtx_attr; 2341 } else if (mask & s->layout.var.vtx_attr.mask) { 2342 *out_mode = ms_out_mode_var; 2343 return &s->layout.var.vtx_attr; 2344 } 2345 } 2346 2347 unreachable("Couldn't figure out mesh shader output mode."); 2348} 2349 2350static void 2351ms_store_arrayed_output_intrin(nir_builder *b, 2352 nir_intrinsic_instr *intrin, 2353 lower_ngg_ms_state *s) 2354{ 2355 ms_out_mode out_mode; 2356 unsigned location = nir_intrinsic_io_semantics(intrin).location; 2357 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s); 2358 update_ms_output_info(intrin, out, s); 2359 2360 /* We compact the LDS size (we don't reserve LDS space for outputs which can 2361 * be stored in variables), so we can't rely on the original driver_location. 2362 * Instead, we compute the first free location based on the output mask. 2363 */ 2364 unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location)); 2365 unsigned component_offset = nir_intrinsic_component(intrin); 2366 unsigned write_mask = nir_intrinsic_write_mask(intrin); 2367 unsigned num_outputs = util_bitcount64(out->mask); 2368 unsigned const_off = out->addr + component_offset * 4; 2369 2370 nir_ssa_def *store_val = regroup_store_val(b, intrin->src[0].ssa); 2371 nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; 2372 nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs); 2373 nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa; 2374 nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u); 2375 nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); 2376 2377 if (out_mode == ms_out_mode_lds) { 2378 nir_store_shared(b, store_val, addr, .base = const_off, 2379 .write_mask = write_mask, .align_mul = 16, 2380 .align_offset = const_off % 16); 2381 } else if (out_mode == ms_out_mode_vram) { 2382 nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b); 2383 nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b); 2384 nir_store_buffer_amd(b, store_val, ring, addr, off, 2385 .base = const_off, 2386 .write_mask = write_mask, 2387 .memory_modes = nir_var_shader_out); 2388 } else if (out_mode == ms_out_mode_var) { 2389 if (store_val->bit_size > 32) { 2390 /* Split 64-bit store values to 32-bit components. */ 2391 store_val = nir_bitcast_vector(b, store_val, 32); 2392 /* Widen the write mask so it is in 32-bit components. */ 2393 write_mask = util_widen_mask(write_mask, store_val->bit_size / 32); 2394 } 2395 2396 u_foreach_bit(comp, write_mask) { 2397 nir_ssa_def *val = nir_channel(b, store_val, comp); 2398 unsigned idx = location * 4 + comp + component_offset; 2399 nir_store_var(b, s->out_variables[idx], val, 0x1); 2400 } 2401 } else { 2402 unreachable("Invalid MS output mode for store"); 2403 } 2404} 2405 2406static nir_ssa_def * 2407ms_load_arrayed_output(nir_builder *b, 2408 nir_ssa_def *arr_index, 2409 nir_ssa_def *base_offset, 2410 unsigned location, 2411 unsigned component_offset, 2412 unsigned num_components, 2413 unsigned load_bit_size, 2414 lower_ngg_ms_state *s) 2415{ 2416 ms_out_mode out_mode; 2417 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s); 2418 2419 unsigned component_addr_off = component_offset * 4; 2420 unsigned num_outputs = util_bitcount64(out->mask); 2421 unsigned const_off = out->addr + component_offset * 4; 2422 2423 /* Use compacted driver location instead of the original. */ 2424 unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location)); 2425 2426 nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs); 2427 nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16); 2428 nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); 2429 2430 if (out_mode == ms_out_mode_lds) { 2431 return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16, 2432 .align_offset = component_addr_off % 16, 2433 .base = const_off); 2434 } else if (out_mode == ms_out_mode_vram) { 2435 nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b); 2436 nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b); 2437 return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, 2438 .base = const_off, 2439 .memory_modes = nir_var_shader_out); 2440 } else if (out_mode == ms_out_mode_var) { 2441 nir_ssa_def *arr[8] = {0}; 2442 unsigned num_32bit_components = num_components * load_bit_size / 32; 2443 for (unsigned comp = 0; comp < num_32bit_components; ++comp) { 2444 unsigned idx = location * 4 + comp + component_addr_off; 2445 arr[comp] = nir_load_var(b, s->out_variables[idx]); 2446 } 2447 if (load_bit_size > 32) 2448 return nir_extract_bits(b, arr, 1, 0, num_components, load_bit_size); 2449 return nir_vec(b, arr, num_components); 2450 } else { 2451 unreachable("Invalid MS output mode for load"); 2452 } 2453} 2454 2455static nir_ssa_def * 2456ms_load_arrayed_output_intrin(nir_builder *b, 2457 nir_intrinsic_instr *intrin, 2458 lower_ngg_ms_state *s) 2459{ 2460 nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; 2461 nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa; 2462 2463 unsigned location = nir_intrinsic_io_semantics(intrin).location; 2464 unsigned component_offset = nir_intrinsic_component(intrin); 2465 unsigned bit_size = intrin->dest.ssa.bit_size; 2466 unsigned num_components = intrin->dest.ssa.num_components; 2467 unsigned load_bit_size = MAX2(bit_size, 32); 2468 2469 nir_ssa_def *load = 2470 ms_load_arrayed_output(b, arr_index, base_offset, location, component_offset, 2471 num_components, load_bit_size, s); 2472 2473 return regroup_load_val(b, load, bit_size); 2474} 2475 2476static nir_ssa_def * 2477lower_ms_load_workgroup_index(nir_builder *b, 2478 UNUSED nir_intrinsic_instr *intrin, 2479 lower_ngg_ms_state *s) 2480{ 2481 return s->workgroup_index; 2482} 2483 2484static nir_ssa_def * 2485update_ms_scoped_barrier(nir_builder *b, 2486 nir_intrinsic_instr *intrin, 2487 lower_ngg_ms_state *s) 2488{ 2489 /* Output loads and stores are lowered to shared memory access, 2490 * so we have to update the barriers to also reflect this. 2491 */ 2492 unsigned mem_modes = nir_intrinsic_memory_modes(intrin); 2493 if (mem_modes & nir_var_shader_out) 2494 mem_modes |= nir_var_mem_shared; 2495 else 2496 return NULL; 2497 2498 nir_intrinsic_set_memory_modes(intrin, mem_modes); 2499 2500 return NIR_LOWER_INSTR_PROGRESS; 2501} 2502 2503static nir_ssa_def * 2504lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state) 2505{ 2506 lower_ngg_ms_state *s = (lower_ngg_ms_state *) state; 2507 2508 if (instr->type != nir_instr_type_intrinsic) 2509 return NULL; 2510 2511 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 2512 2513 switch (intrin->intrinsic) { 2514 case nir_intrinsic_store_output: 2515 return lower_ms_store_output(b, intrin, s); 2516 case nir_intrinsic_load_output: 2517 return lower_ms_load_output(b, intrin, s); 2518 case nir_intrinsic_store_per_vertex_output: 2519 case nir_intrinsic_store_per_primitive_output: 2520 ms_store_arrayed_output_intrin(b, intrin, s); 2521 return NIR_LOWER_INSTR_PROGRESS_REPLACE; 2522 case nir_intrinsic_load_per_vertex_output: 2523 case nir_intrinsic_load_per_primitive_output: 2524 return ms_load_arrayed_output_intrin(b, intrin, s); 2525 case nir_intrinsic_scoped_barrier: 2526 return update_ms_scoped_barrier(b, intrin, s); 2527 case nir_intrinsic_load_workgroup_index: 2528 return lower_ms_load_workgroup_index(b, intrin, s); 2529 default: 2530 unreachable("Not a lowerable mesh shader intrinsic."); 2531 } 2532} 2533 2534static bool 2535filter_ms_intrinsic(const nir_instr *instr, 2536 UNUSED const void *st) 2537{ 2538 if (instr->type != nir_instr_type_intrinsic) 2539 return false; 2540 2541 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 2542 return intrin->intrinsic == nir_intrinsic_store_output || 2543 intrin->intrinsic == nir_intrinsic_load_output || 2544 intrin->intrinsic == nir_intrinsic_store_per_vertex_output || 2545 intrin->intrinsic == nir_intrinsic_load_per_vertex_output || 2546 intrin->intrinsic == nir_intrinsic_store_per_primitive_output || 2547 intrin->intrinsic == nir_intrinsic_load_per_primitive_output || 2548 intrin->intrinsic == nir_intrinsic_scoped_barrier || 2549 intrin->intrinsic == nir_intrinsic_load_workgroup_index; 2550} 2551 2552static void 2553lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s) 2554{ 2555 nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s); 2556} 2557 2558static void 2559ms_emit_arrayed_outputs(nir_builder *b, 2560 nir_ssa_def *invocation_index, 2561 uint64_t mask, 2562 lower_ngg_ms_state *s) 2563{ 2564 nir_ssa_def *zero = nir_imm_int(b, 0); 2565 2566 u_foreach_bit64(slot, mask) { 2567 /* Should not occour here, handled separately. */ 2568 assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES); 2569 2570 const nir_io_semantics io_sem = { .location = slot, .num_slots = 1 }; 2571 unsigned component_mask = s->output_info[slot].components_mask; 2572 2573 while (component_mask) { 2574 int start_comp = 0, num_components = 1; 2575 u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components); 2576 2577 nir_ssa_def *load = 2578 ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp, 2579 num_components, 32, s); 2580 2581 nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp, 2582 .io_semantics = io_sem); 2583 } 2584 } 2585} 2586 2587static void 2588emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s) 2589{ 2590 b->cursor = nir_before_cf_list(&b->impl->body); 2591 2592 /* Initialize NIR variables for same-invocation outputs. */ 2593 uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask; 2594 2595 u_foreach_bit64(slot, same_invocation_output_mask) { 2596 for (unsigned comp = 0; comp < 4; ++comp) { 2597 unsigned idx = slot * 4 + comp; 2598 s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output"); 2599 nir_store_var(b, s->out_variables[idx], nir_imm_int(b, 0), 0x1); 2600 } 2601 } 2602 2603 bool uses_workgroup_id = 2604 BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID) || 2605 BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX); 2606 2607 if (!uses_workgroup_id) 2608 return; 2609 2610 /* The HW doesn't support a proper workgroup index for vertex processing stages, 2611 * so we use the vertex ID which is equivalent to the index of the current workgroup 2612 * within the current dispatch. 2613 * 2614 * Due to the register programming of mesh shaders, this value is only filled for 2615 * the first invocation of the first wave. To let other waves know, we use LDS. 2616 */ 2617 nir_ssa_def *workgroup_index = nir_load_vertex_id_zero_base(b); 2618 2619 if (s->api_workgroup_size <= s->wave_size) { 2620 /* API workgroup is small, so we don't need to use LDS. */ 2621 s->workgroup_index = nir_read_first_invocation(b, workgroup_index); 2622 return; 2623 } 2624 2625 unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index; 2626 2627 nir_ssa_def *zero = nir_imm_int(b, 0); 2628 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32); 2629 nir_ssa_def *loaded_workgroup_index = NULL; 2630 2631 /* Use elect to make sure only 1 invocation uses LDS. */ 2632 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); 2633 { 2634 nir_ssa_def *wave_id = nir_load_subgroup_id(b); 2635 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0)); 2636 { 2637 nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr); 2638 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 2639 .memory_scope = NIR_SCOPE_WORKGROUP, 2640 .memory_semantics = NIR_MEMORY_ACQ_REL, 2641 .memory_modes = nir_var_mem_shared); 2642 } 2643 nir_push_else(b, if_wave_0); 2644 { 2645 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 2646 .memory_scope = NIR_SCOPE_WORKGROUP, 2647 .memory_semantics = NIR_MEMORY_ACQ_REL, 2648 .memory_modes = nir_var_mem_shared); 2649 loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr); 2650 } 2651 nir_pop_if(b, if_wave_0); 2652 2653 workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index); 2654 } 2655 nir_pop_if(b, if_elected); 2656 2657 workgroup_index = nir_if_phi(b, workgroup_index, dont_care); 2658 s->workgroup_index = nir_read_first_invocation(b, workgroup_index); 2659} 2660 2661static void 2662set_nv_ms_final_output_counts(nir_builder *b, 2663 lower_ngg_ms_state *s, 2664 nir_ssa_def **out_num_prm, 2665 nir_ssa_def **out_num_vtx) 2666{ 2667 /* Limitations of the NV extension: 2668 * - Number of primitives can be written and read by any invocation, 2669 * so we have to store/load it to/from LDS to make sure the general case works. 2670 * - Number of vertices is not actually known, so we just always use the 2671 * maximum number here. 2672 */ 2673 nir_ssa_def *loaded_num_prm; 2674 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32); 2675 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); 2676 { 2677 loaded_num_prm = ms_load_num_prims(b, s); 2678 } 2679 nir_pop_if(b, if_elected); 2680 loaded_num_prm = nir_if_phi(b, loaded_num_prm, dont_care); 2681 nir_ssa_def *num_prm = nir_read_first_invocation(b, loaded_num_prm); 2682 nir_ssa_def *num_vtx = nir_imm_int(b, b->shader->info.mesh.max_vertices_out); 2683 num_prm = nir_umin(b, num_prm, nir_imm_int(b, b->shader->info.mesh.max_primitives_out)); 2684 2685 /* If the shader doesn't actually create any primitives, don't allocate any output. */ 2686 num_vtx = nir_bcsel(b, nir_ieq_imm(b, num_prm, 0), nir_imm_int(b, 0), num_vtx); 2687 2688 /* Emit GS_ALLOC_REQ on Wave 0 to let the HW know the output size. */ 2689 nir_ssa_def *wave_id = nir_load_subgroup_id(b); 2690 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0)); 2691 { 2692 nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm); 2693 } 2694 nir_pop_if(b, if_wave_0); 2695 2696 *out_num_prm = num_prm; 2697 *out_num_vtx = num_vtx; 2698} 2699 2700static void 2701emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) 2702{ 2703 /* We assume there is always a single end block in the shader. */ 2704 nir_block *last_block = nir_impl_last_block(b->impl); 2705 b->cursor = nir_after_block(last_block); 2706 2707 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, 2708 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared); 2709 2710 nir_ssa_def *num_prm; 2711 nir_ssa_def *num_vtx; 2712 2713 set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx); 2714 2715 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); 2716 2717 /* Load vertex/primitive attributes from shared memory and 2718 * emit store_output intrinsics for them. 2719 * 2720 * Contrary to the semantics of the API mesh shader, these are now 2721 * compliant with NGG HW semantics, meaning that these store the 2722 * current thread's vertex attributes in a way the HW can export. 2723 */ 2724 2725 /* Export vertices. */ 2726 nir_ssa_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx); 2727 nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex); 2728 { 2729 /* All per-vertex attributes. */ 2730 ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs, s); 2731 nir_export_vertex_amd(b); 2732 } 2733 nir_pop_if(b, if_has_output_vertex); 2734 2735 /* Export primitives. */ 2736 nir_ssa_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm); 2737 nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive); 2738 { 2739 /* Generic per-primitive attributes. */ 2740 ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s); 2741 2742 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */ 2743 if (s->insert_layer_output) { 2744 nir_ssa_def *layer = nir_load_view_index(b); 2745 const nir_io_semantics io_sem = { .location = VARYING_SLOT_LAYER, .num_slots = 1 }; 2746 nir_store_output(b, layer, nir_imm_int(b, 0), .base = VARYING_SLOT_LAYER, .component = 0, .io_semantics = io_sem); 2747 b->shader->info.outputs_written |= VARYING_BIT_LAYER; 2748 b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER; 2749 } 2750 2751 /* Primitive connectivity data: describes which vertices the primitive uses. */ 2752 nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim); 2753 nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr); 2754 nir_ssa_def *indices[3]; 2755 nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u); 2756 2757 for (unsigned i = 0; i < s->vertices_per_prim; ++i) { 2758 indices[i] = nir_u2u32(b, nir_channel(b, indices_loaded, i)); 2759 indices[i] = nir_umin(b, indices[i], max_vtx_idx); 2760 } 2761 2762 nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false); 2763 nir_export_primitive_amd(b, prim_exp_arg); 2764 } 2765 nir_pop_if(b, if_has_output_primitive); 2766} 2767 2768static void 2769handle_smaller_ms_api_workgroup(nir_builder *b, 2770 lower_ngg_ms_state *s) 2771{ 2772 if (s->api_workgroup_size >= s->hw_workgroup_size) 2773 return; 2774 2775 /* Handle barriers manually when the API workgroup 2776 * size is less than the HW workgroup size. 2777 * 2778 * The problem is that the real workgroup launched on NGG HW 2779 * will be larger than the size specified by the API, and the 2780 * extra waves need to keep up with barriers in the API waves. 2781 * 2782 * There are 2 different cases: 2783 * 1. The whole API workgroup fits in a single wave. 2784 * We can shrink the barriers to subgroup scope and 2785 * don't need to insert any extra ones. 2786 * 2. The API workgroup occupies multiple waves, but not 2787 * all. In this case, we emit code that consumes every 2788 * barrier on the extra waves. 2789 */ 2790 assert(s->hw_workgroup_size % s->wave_size == 0); 2791 bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size; 2792 bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size; 2793 bool need_additional_barriers = scan_barriers && !can_shrink_barriers; 2794 2795 unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves; 2796 unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size); 2797 2798 /* Scan the shader for workgroup barriers. */ 2799 if (scan_barriers) { 2800 bool has_any_workgroup_barriers = false; 2801 2802 nir_foreach_block(block, b->impl) { 2803 nir_foreach_instr_safe(instr, block) { 2804 if (instr->type != nir_instr_type_intrinsic) 2805 continue; 2806 2807 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 2808 bool is_workgroup_barrier = 2809 intrin->intrinsic == nir_intrinsic_scoped_barrier && 2810 nir_intrinsic_execution_scope(intrin) == NIR_SCOPE_WORKGROUP; 2811 2812 if (!is_workgroup_barrier) 2813 continue; 2814 2815 if (can_shrink_barriers) { 2816 /* Every API invocation runs in the first wave. 2817 * In this case, we can change the barriers to subgroup scope 2818 * and avoid adding additional barriers. 2819 */ 2820 nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP); 2821 nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP); 2822 } else { 2823 has_any_workgroup_barriers = true; 2824 } 2825 } 2826 } 2827 2828 need_additional_barriers &= has_any_workgroup_barriers; 2829 } 2830 2831 /* Extract the full control flow of the shader. */ 2832 nir_cf_list extracted; 2833 nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body)); 2834 b->cursor = nir_before_cf_list(&b->impl->body); 2835 2836 /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */ 2837 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); 2838 nir_ssa_def *zero = nir_imm_int(b, 0); 2839 2840 if (need_additional_barriers) { 2841 /* First invocation stores 0 to number of API waves in flight. */ 2842 nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); 2843 { 2844 nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr); 2845 } 2846 nir_pop_if(b, if_first_in_workgroup); 2847 2848 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 2849 .memory_scope = NIR_SCOPE_WORKGROUP, 2850 .memory_semantics = NIR_MEMORY_ACQ_REL, 2851 .memory_modes = nir_var_shader_out | nir_var_mem_shared); 2852 } 2853 2854 nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size)); 2855 nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation); 2856 { 2857 nir_cf_reinsert(&extracted, b->cursor); 2858 b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list); 2859 2860 if (need_additional_barriers) { 2861 /* One invocation in each API wave decrements the number of API waves in flight. */ 2862 nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1)); 2863 { 2864 nir_shared_atomic_add(b, 32, zero, nir_imm_int(b, -1u), .base = api_waves_in_flight_addr); 2865 } 2866 nir_pop_if(b, if_elected_again); 2867 2868 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 2869 .memory_scope = NIR_SCOPE_WORKGROUP, 2870 .memory_semantics = NIR_MEMORY_ACQ_REL, 2871 .memory_modes = nir_var_shader_out | nir_var_mem_shared); 2872 } 2873 } 2874 nir_pop_if(b, if_has_api_ms_invocation); 2875 2876 if (need_additional_barriers) { 2877 /* Make sure that waves that don't run any API invocations execute 2878 * the same amount of barriers as those that do. 2879 * 2880 * We do this by executing a barrier until the number of API waves 2881 * in flight becomes zero. 2882 */ 2883 nir_ssa_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation); 2884 nir_ssa_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0); 2885 nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms); 2886 { 2887 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); 2888 { 2889 nir_loop *loop = nir_push_loop(b); 2890 { 2891 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, 2892 .memory_scope = NIR_SCOPE_WORKGROUP, 2893 .memory_semantics = NIR_MEMORY_ACQ_REL, 2894 .memory_modes = nir_var_shader_out | nir_var_mem_shared); 2895 2896 nir_ssa_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr); 2897 nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0)); 2898 { 2899 nir_jump(b, nir_jump_break); 2900 } 2901 nir_pop_if(b, if_break); 2902 } 2903 nir_pop_loop(b, loop); 2904 } 2905 nir_pop_if(b, if_elected); 2906 } 2907 nir_pop_if(b, if_wave_has_no_api_ms); 2908 } 2909} 2910 2911static void 2912ms_move_output(ms_out_part *from, ms_out_part *to) 2913{ 2914 uint64_t loc = util_logbase2_64(from->mask); 2915 uint64_t bit = BITFIELD64_BIT(loc); 2916 from->mask ^= bit; 2917 to->mask |= bit; 2918} 2919 2920static void 2921ms_calculate_arrayed_output_layout(ms_out_mem_layout *l, 2922 unsigned max_vertices, 2923 unsigned max_primitives) 2924{ 2925 uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16; 2926 uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16; 2927 l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16); 2928 l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size; 2929 2930 uint32_t vram_vtx_attr_size = util_bitcount64(l->vram.vtx_attr.mask) * max_vertices * 16; 2931 l->vram.prm_attr.addr = ALIGN(l->vram.vtx_attr.addr + vram_vtx_attr_size, 16); 2932} 2933 2934static ms_out_mem_layout 2935ms_calculate_output_layout(unsigned api_shared_size, 2936 uint64_t per_vertex_output_mask, 2937 uint64_t per_primitive_output_mask, 2938 uint64_t cross_invocation_output_access, 2939 unsigned max_vertices, 2940 unsigned max_primitives, 2941 unsigned vertices_per_prim) 2942{ 2943 uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access; 2944 uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access; 2945 2946 /* Shared memory used by the API shader. */ 2947 ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } }; 2948 2949 /* Outputs without cross-invocation access can be stored in variables. */ 2950 l.var.vtx_attr.mask = per_vertex_output_mask & ~lds_per_vertex_output_mask; 2951 l.var.prm_attr.mask = per_primitive_output_mask & ~lds_per_primitive_output_mask; 2952 2953 /* Workgroup information, see ms_workgroup_* for the layout. */ 2954 l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16); 2955 l.lds.total_size = l.lds.workgroup_info_addr + 16; 2956 2957 /* Per-vertex and per-primitive output attributes. 2958 * Outputs without cross-invocation access are not included here. 2959 * First, try to put all outputs into LDS (shared memory). 2960 * If they don't fit, try to move them to VRAM one by one. 2961 */ 2962 l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16); 2963 l.lds.vtx_attr.mask = lds_per_vertex_output_mask; 2964 l.lds.prm_attr.mask = lds_per_primitive_output_mask; 2965 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); 2966 2967 /* NGG shaders can only address up to 32K LDS memory. 2968 * The spec requires us to allow the application to use at least up to 28K 2969 * shared memory. Additionally, we reserve 2K for driver internal use 2970 * (eg. primitive indices and such, see below). 2971 * 2972 * Move the outputs that do not fit LDS, to VRAM. 2973 * Start with per-primitive attributes, because those are grouped at the end. 2974 */ 2975 while (l.lds.total_size >= 30 * 1024) { 2976 if (l.lds.prm_attr.mask) 2977 ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr); 2978 else if (l.lds.vtx_attr.mask) 2979 ms_move_output(&l.lds.vtx_attr, &l.vram.vtx_attr); 2980 else 2981 unreachable("API shader uses too much shared memory."); 2982 2983 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); 2984 } 2985 2986 /* Indices: flat array of 8-bit vertex indices for each primitive. */ 2987 l.lds.indices_addr = ALIGN(l.lds.total_size, 16); 2988 l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim; 2989 2990 /* NGG is only allowed to address up to 32K of LDS. */ 2991 assert(l.lds.total_size <= 32 * 1024); 2992 return l; 2993} 2994 2995void 2996ac_nir_lower_ngg_ms(nir_shader *shader, 2997 bool *out_needs_scratch_ring, 2998 unsigned wave_size, 2999 bool multiview) 3000{ 3001 unsigned vertices_per_prim = 3002 num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type); 3003 3004 uint64_t special_outputs = 3005 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES); 3006 uint64_t per_vertex_outputs = 3007 shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs; 3008 uint64_t per_primitive_outputs = 3009 shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs; 3010 3011 /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */ 3012 uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access | 3013 shader->info.outputs_accessed_indirectly; 3014 3015 unsigned max_vertices = shader->info.mesh.max_vertices_out; 3016 unsigned max_primitives = shader->info.mesh.max_primitives_out; 3017 3018 ms_out_mem_layout layout = 3019 ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs, 3020 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim); 3021 3022 shader->info.shared_size = layout.lds.total_size; 3023 *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask; 3024 3025 /* The workgroup size that is specified by the API shader may be different 3026 * from the size of the workgroup that actually runs on the HW, due to the 3027 * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed. 3028 * 3029 * Therefore, we must make sure that when the API workgroup size is smaller, 3030 * we don't run the API shader on more HW invocations than is necessary. 3031 */ 3032 unsigned api_workgroup_size = shader->info.workgroup_size[0] * 3033 shader->info.workgroup_size[1] * 3034 shader->info.workgroup_size[2]; 3035 3036 unsigned hw_workgroup_size = 3037 ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size); 3038 3039 lower_ngg_ms_state state = { 3040 .layout = layout, 3041 .wave_size = wave_size, 3042 .per_vertex_outputs = per_vertex_outputs, 3043 .per_primitive_outputs = per_primitive_outputs, 3044 .vertices_per_prim = vertices_per_prim, 3045 .api_workgroup_size = api_workgroup_size, 3046 .hw_workgroup_size = hw_workgroup_size, 3047 .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER), 3048 }; 3049 3050 nir_function_impl *impl = nir_shader_get_entrypoint(shader); 3051 assert(impl); 3052 3053 nir_builder builder; 3054 nir_builder *b = &builder; /* This is to avoid the & */ 3055 nir_builder_init(b, impl); 3056 b->cursor = nir_before_cf_list(&impl->body); 3057 3058 handle_smaller_ms_api_workgroup(b, &state); 3059 emit_ms_prelude(b, &state); 3060 nir_metadata_preserve(impl, nir_metadata_none); 3061 3062 lower_ms_intrinsics(shader, &state); 3063 3064 emit_ms_finale(b, &state); 3065 nir_metadata_preserve(impl, nir_metadata_none); 3066 3067 /* Cleanup */ 3068 nir_lower_vars_to_ssa(shader); 3069 nir_remove_dead_variables(shader, nir_var_function_temp, NULL); 3070 nir_lower_alu_to_scalar(shader, NULL, NULL); 3071 nir_lower_phis_to_scalar(shader, true); 3072 3073 nir_validate_shader(shader, "after emitting NGG MS"); 3074} 3075