/*
 * Copyright 2017 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * on the rights to use, copy, modify, merge, publish, distribute, sub
 * license, and/or sell copies of the Software, and to permit persons to whom
 * the Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
 * USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include "ac_llvm_cull.h"
#include "si_pipe.h"
#include "si_query.h"
#include "si_shader_internal.h"
#include "sid.h"
#include "util/u_memory.h"
#include "util/u_prim.h"

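/* merged_wave_info layout, as unpacked by the helpers below: bits [24:27] hold
 * the wave index within the threadgroup and bits [28:31] hold the number of
 * waves in the threadgroup.
 */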
static LLVMValueRef get_wave_id_in_tg(struct si_shader_context *ctx)
{
   return si_unpack_param(ctx, ctx->args.merged_wave_info, 24, 4);
}

static LLVMValueRef get_tgsize(struct si_shader_context *ctx)
{
   return si_unpack_param(ctx, ctx->args.merged_wave_info, 28, 4);
}

LLVMValueRef gfx10_get_thread_id_in_tg(struct si_shader_context *ctx)
{
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef tmp;
   tmp = LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
                      LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, false), "");
   return LLVMBuildAdd(builder, tmp, ac_get_thread_id(&ctx->ac), "");
}

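/* gs_tg_info layout, as unpacked by the helpers below: bits [0:11] hold the
 * streamout ordered ID, bits [12:20] the threadgroup vertex count, and bits
 * [22:30] the threadgroup primitive count.
 */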
static LLVMValueRef ngg_get_vtx_cnt(struct si_shader_context *ctx)
{
   return si_unpack_param(ctx, ctx->args.gs_tg_info, 12, 9);
}

static LLVMValueRef ngg_get_prim_cnt(struct si_shader_context *ctx)
{
   return si_unpack_param(ctx, ctx->args.gs_tg_info, 22, 9);
}

static LLVMValueRef ngg_get_ordered_id(struct si_shader_context *ctx)
{
   return si_unpack_param(ctx, ctx->args.gs_tg_info, 0, 12);
}

static LLVMValueRef ngg_get_query_buf(struct si_shader_context *ctx)
{
   LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);

   return ac_build_load_to_sgpr(&ctx->ac, buf_ptr,
                                LLVMConstInt(ctx->ac.i32, SI_GS_QUERY_BUF, false));
}

static LLVMValueRef ngg_get_emulated_counters_buf(struct si_shader_context *ctx)
{
   LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);

   return ac_build_load_to_sgpr(&ctx->ac, buf_ptr,
                                LLVMConstInt(ctx->ac.i32, SI_GS_QUERY_EMULATED_COUNTERS_BUF, false));
}

/**
 * Return the maximum number of vertices per primitive as a constant in
 * \p num_vertices, and return the exact (possibly non-constant) value as an
 * LLVMValueRef from the function.
 */
static LLVMValueRef ngg_get_vertices_per_prim(struct si_shader_context *ctx, unsigned *num_vertices)
{
   const struct si_shader_info *info = &ctx->shader->selector->info;

   if (ctx->stage == MESA_SHADER_GEOMETRY) {
      *num_vertices = u_vertices_per_prim(info->base.gs.output_primitive);
      return LLVMConstInt(ctx->ac.i32, *num_vertices, false);
   } else if (ctx->stage == MESA_SHADER_VERTEX) {
      if (info->base.vs.blit_sgprs_amd) {
         /* Blits always use axis-aligned rectangles with 3 vertices. */
         *num_vertices = 3;
         return LLVMConstInt(ctx->ac.i32, 3, 0);
      } else if (ctx->shader->key.ge.opt.ngg_culling & SI_NGG_CULL_LINES) {
         *num_vertices = 2;
         return LLVMConstInt(ctx->ac.i32, 2, 0);
      } else {
         /* We always build up all three indices for the prim export
          * independent of the primitive type. The additional garbage
          * data shouldn't hurt. This is used by exports and streamout.
          */
         *num_vertices = 3;

         /* Extract OUTPRIM field. */
         LLVMValueRef num = GET_FIELD(ctx, GS_STATE_OUTPRIM);
         return LLVMBuildAdd(ctx->ac.builder, num, ctx->ac.i32_1, "");
      }
   } else {
      assert(ctx->stage == MESA_SHADER_TESS_EVAL);

      if (info->base.tess.point_mode)
         *num_vertices = 1;
      else if (info->base.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES)
         *num_vertices = 2;
      else
         *num_vertices = 3;

      return LLVMConstInt(ctx->ac.i32, *num_vertices, false);
   }
}

bool gfx10_ngg_export_prim_early(struct si_shader *shader)
{
   struct si_shader_selector *sel = shader->selector;

   assert(shader->key.ge.as_ngg && !shader->key.ge.as_es);

   return sel->stage != MESA_SHADER_GEOMETRY &&
          !gfx10_ngg_writes_user_edgeflags(shader);
}

void gfx10_ngg_build_sendmsg_gs_alloc_req(struct si_shader_context *ctx)
{
   /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
   if (gfx10_is_ngg_passthrough(ctx->shader) &&
       ctx->screen->info.family >= CHIP_NAVI23)
      return;

   ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), ngg_get_vtx_cnt(ctx),
                                 ngg_get_prim_cnt(ctx));
}

void gfx10_ngg_build_export_prim(struct si_shader_context *ctx, LLVMValueRef user_edgeflags[3],
                                 LLVMValueRef prim_passthrough)
{
   LLVMBuilderRef builder = ctx->ac.builder;

   if (gfx10_is_ngg_passthrough(ctx->shader) || ctx->shader->key.ge.opt.ngg_culling) {
      ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 6001);
      {
         struct ac_ngg_prim prim = {};

         if (prim_passthrough)
            prim.passthrough = prim_passthrough;
         else
            prim.passthrough = ac_get_arg(&ctx->ac, ctx->args.gs_vtx_offset[0]);

         /* This is only used with NGG culling, which returns the NGG
          * passthrough prim export encoding.
          */
         if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
            unsigned all_bits_no_edgeflags = ~SI_NGG_PRIM_EDGE_FLAG_BITS;
            LLVMValueRef edgeflags = LLVMConstInt(ctx->ac.i32, all_bits_no_edgeflags, 0);

            unsigned num_vertices;
            ngg_get_vertices_per_prim(ctx, &num_vertices);

            for (unsigned i = 0; i < num_vertices; i++) {
               unsigned shift = 9 + i * 10;
               LLVMValueRef edge;

               edge = LLVMBuildLoad2(builder, ctx->ac.i1, user_edgeflags[i], "");
               edge = LLVMBuildZExt(builder, edge, ctx->ac.i32, "");
               edge = LLVMBuildShl(builder, edge, LLVMConstInt(ctx->ac.i32, shift, 0), "");
               edgeflags = LLVMBuildOr(builder, edgeflags, edge, "");
            }
            prim.passthrough = LLVMBuildAnd(builder, prim.passthrough, edgeflags, "");
         }

         ac_build_export_prim(&ctx->ac, &prim);
      }
      ac_build_endif(&ctx->ac, 6001);
      return;
   }

   ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 6001);
   {
      struct ac_ngg_prim prim = {};

      ngg_get_vertices_per_prim(ctx, &prim.num_vertices);

      prim.isnull = ctx->ac.i1false;

      if (gfx10_edgeflags_have_effect(ctx->shader))
         prim.edgeflags = ac_pack_edgeflags_for_export(&ctx->ac, &ctx->args);
      else
         prim.edgeflags = ctx->ac.i32_0;

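      /* Each gs_vtx_offset argument packs two 16-bit vertex indices; take the
       * i-th index from the low or high half of argument i / 2.
       */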
      for (unsigned i = 0; i < prim.num_vertices; ++i)
         prim.index[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[i / 2], (i & 1) * 16, 16);

      if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
         LLVMValueRef edgeflags = ctx->ac.i32_0;

         for (unsigned i = 0; i < prim.num_vertices; ++i) {
            LLVMValueRef edge;

            edge = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.i1, user_edgeflags[i], "");
            edge = LLVMBuildZExt(ctx->ac.builder, edge, ctx->ac.i32, "");
            edge = LLVMBuildShl(ctx->ac.builder, edge, LLVMConstInt(ctx->ac.i32, 9 + i*10, 0), "");
            edgeflags = LLVMBuildOr(ctx->ac.builder, edgeflags, edge, "");
         }
         prim.edgeflags = LLVMBuildAnd(ctx->ac.builder, prim.edgeflags, edgeflags, "");
      }

      ac_build_export_prim(&ctx->ac, &prim);
   }
   ac_build_endif(&ctx->ac, 6001);
}

static void build_streamout_vertex(struct si_shader_context *ctx, LLVMValueRef *so_buffer,
                                   LLVMValueRef *wg_offset_dw, unsigned stream,
                                   LLVMValueRef offset_vtx, LLVMValueRef vertexptr)
{
   struct si_shader_info *info = &ctx->shader->selector->info;
   struct pipe_stream_output_info *so = &ctx->so;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef offset[4] = {};
   LLVMValueRef tmp;

   for (unsigned buffer = 0; buffer < 4; ++buffer) {
      if (!wg_offset_dw[buffer])
         continue;

      tmp = LLVMBuildMul(builder, offset_vtx, LLVMConstInt(ctx->ac.i32, so->stride[buffer], false),
                         "");
      tmp = LLVMBuildAdd(builder, wg_offset_dw[buffer], tmp, "");
      offset[buffer] = LLVMBuildShl(builder, tmp, LLVMConstInt(ctx->ac.i32, 2, false), "");
   }

   for (unsigned i = 0; i < so->num_outputs; ++i) {
      if (so->output[i].stream != stream)
         continue;

      unsigned reg = so->output[i].register_index;
      struct si_shader_output_values out;
      out.semantic = info->output_semantic[reg];
      out.vertex_streams = info->output_streams[reg];

      for (unsigned comp = 0; comp < 4; comp++) {
         tmp = ac_build_gep0(&ctx->ac, vertexptr, LLVMConstInt(ctx->ac.i32, 4 * reg + comp, false));
         out.values[comp] = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
      }

      si_llvm_streamout_store_output(ctx, so_buffer, offset, &so->output[i], &out);
   }
}

struct ngg_streamout {
   LLVMValueRef num_vertices;

   /* per-thread data */
   LLVMValueRef prim_enable[4]; /* i1 per stream */
   LLVMValueRef vertices[3];    /* [N x i32] addrspace(LDS)* */

   /* Output */
   LLVMValueRef emit[4]; /* per-stream emitted primitives (only valid for used streams) */
};

/**
 * Build streamout logic.
 *
 * Implies a barrier.
 *
 * Writes number of emitted primitives to gs_ngg_scratch[4:8].
 *
 * Clobbers gs_ngg_scratch[8:].
 */
static void build_streamout(struct si_shader_context *ctx, struct ngg_streamout *nggso)
{
   struct si_shader_info *info = &ctx->shader->selector->info;
   struct pipe_stream_output_info *so = &ctx->so;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);
   LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
   LLVMValueRef tmp, tmp2;
   LLVMValueRef i32_2 = LLVMConstInt(ctx->ac.i32, 2, false);
   LLVMValueRef i32_4 = LLVMConstInt(ctx->ac.i32, 4, false);
   LLVMValueRef i32_8 = LLVMConstInt(ctx->ac.i32, 8, false);
   LLVMValueRef so_buffer[4] = {};
   unsigned max_num_vertices = 1 + (nggso->vertices[1] ? 1 : 0) + (nggso->vertices[2] ? 1 : 0);
   LLVMValueRef prim_stride_dw[4] = {};
   LLVMValueRef prim_stride_dw_vgpr = LLVMGetUndef(ctx->ac.i32);
   int stream_for_buffer[4] = {-1, -1, -1, -1};
   unsigned bufmask_for_stream[4] = {};
   bool isgs = ctx->stage == MESA_SHADER_GEOMETRY;
   unsigned scratch_emit_base = isgs ? 4 : 0;
   LLVMValueRef scratch_emit_basev = isgs ? i32_4 : ctx->ac.i32_0;
   unsigned scratch_offset_base = isgs ? 8 : 4;
   LLVMValueRef scratch_offset_basev = isgs ? i32_8 : i32_4;
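   /* gs_ngg_scratch dword layout used below: with a GS, per-stream emit counts go
    * to dwords 4..7 and per-buffer offsets to dwords 8..11; without a GS, emit
    * counts use dwords 0..3 and offsets dwords 4..7.
    */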

   /* Determine the mapping of streamout buffers to vertex streams. */
   for (unsigned i = 0; i < so->num_outputs; ++i) {
      unsigned buf = so->output[i].output_buffer;
      unsigned stream = so->output[i].stream;
      assert(stream_for_buffer[buf] < 0 || stream_for_buffer[buf] == stream);
      stream_for_buffer[buf] = stream;
      bufmask_for_stream[stream] |= 1 << buf;
   }

   for (unsigned buffer = 0; buffer < 4; ++buffer) {
      if (stream_for_buffer[buffer] == -1)
         continue;

      assert(so->stride[buffer]);

      tmp = LLVMConstInt(ctx->ac.i32, so->stride[buffer], false);
      prim_stride_dw[buffer] = LLVMBuildMul(builder, tmp, nggso->num_vertices, "");
      prim_stride_dw_vgpr =
         ac_build_writelane(&ctx->ac, prim_stride_dw_vgpr, prim_stride_dw[buffer],
                            LLVMConstInt(ctx->ac.i32, buffer, false));

      so_buffer[buffer] = ac_build_load_to_sgpr(
         &ctx->ac, buf_ptr, LLVMConstInt(ctx->ac.i32, SI_VS_STREAMOUT_BUF0 + buffer, false));
   }

   tmp = LLVMBuildICmp(builder, LLVMIntEQ, get_wave_id_in_tg(ctx), ctx->ac.i32_0, "");
   ac_build_ifcc(&ctx->ac, tmp, 5200);
   {
      LLVMTypeRef gdsptr = LLVMPointerType(ctx->ac.i32, AC_ADDR_SPACE_GDS);
      LLVMValueRef gdsbase = LLVMBuildIntToPtr(builder, ctx->ac.i32_0, gdsptr, "");

      /* Advance the streamout offsets in GDS. */
      LLVMValueRef offsets_vgpr = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
      LLVMValueRef generated_by_stream_vgpr = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");

      tmp = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), i32_4, "");
      ac_build_ifcc(&ctx->ac, tmp, 5210);
      {
         if (isgs) {
            tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tid);
            tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
         } else {
            tmp = ac_build_writelane(&ctx->ac, ctx->ac.i32_0, ngg_get_prim_cnt(ctx), ctx->ac.i32_0);
         }
         LLVMBuildStore(builder, tmp, generated_by_stream_vgpr);

         unsigned swizzle[4];
         int unused_stream = -1;
         for (unsigned stream = 0; stream < 4; ++stream) {
            if (!info->num_stream_output_components[stream]) {
               unused_stream = stream;
               break;
            }
         }
         for (unsigned buffer = 0; buffer < 4; ++buffer) {
            if (stream_for_buffer[buffer] >= 0) {
               swizzle[buffer] = stream_for_buffer[buffer];
            } else {
               assert(unused_stream >= 0);
               swizzle[buffer] = unused_stream;
            }
         }

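         /* The quad swizzle moves each buffer's source-stream primitive count
          * into lane "buffer", so the multiply below yields the generated dword
          * count per buffer in lanes 0..3.
          */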
         tmp = ac_build_quad_swizzle(&ctx->ac, tmp, swizzle[0], swizzle[1], swizzle[2], swizzle[3]);
         tmp = LLVMBuildMul(builder, tmp, prim_stride_dw_vgpr, "");

         LLVMValueRef args[8] = {
            LLVMBuildIntToPtr(builder, ngg_get_ordered_id(ctx), gdsptr, ""),
            ctx->ac.i32_0,                             /* value to add */
            ctx->ac.i32_0,                             /* ordering */
            ctx->ac.i32_0,                             /* scope */
            ctx->ac.i1false,                           /* isVolatile */
            LLVMConstInt(ctx->ac.i32, 1 << 24, false), /* OA index, bits 24+: lane count */
            ctx->ac.i1true,                            /* wave release */
            ctx->ac.i1true,                            /* wave done */
         };
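         /* Semantics assumed for llvm.amdgcn.ds.ordered.add here: it atomically
          * adds the "value to add" operand to a GDS counter in dispatch order
          * across workgroups (the counter index is encoded in the pointer
          * operand) and returns the pre-add value, i.e. the dword offset where
          * this workgroup may write its streamout data.
          */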

         if (ctx->screen->info.gfx_level >= GFX11) {
            /* Gfx11 GDS instructions only operate on the first active lane. All other lanes are
             * ignored. So are their EXEC bits. This uses the mutex feature of ds_ordered_count
             * to emulate a multi-dword atomic.
             *
             * This is the expected code:
             *    ds_ordered_count release=0 done=0   // lock mutex
             *    ds_add_rtn_u32 dwords_written0
             *    ds_add_rtn_u32 dwords_written1
             *    ds_add_rtn_u32 dwords_written2
             *    ds_add_rtn_u32 dwords_written3
             *    ds_ordered_count release=1 done=1   // unlock mutex
             *
             * TODO: Increment GDS_STRMOUT registers instead of GDS memory.
             */
            LLVMValueRef dwords_written[4] = {tmp, tmp, tmp, tmp};

            /* Move all 4 VGPRs from other lanes to lane 0. */
            for (unsigned i = 1; i < 4; i++) {
               if (ctx->shader->selector->info.base.xfb_stride[i])
                  dwords_written[i] = ac_build_quad_swizzle(&ctx->ac, tmp, i, i, i, i);
            }

            /* Set release=0 to start a GDS mutex. Set done=0 because it's not the last one. */
            args[6] = args[7] = ctx->ac.i1false;
            ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
                               args, ARRAY_SIZE(args), 0);
            ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);

            for (unsigned i = 0; i < 4; i++) {
               if (ctx->shader->selector->info.base.xfb_stride[i]) {
                  LLVMValueRef gds_ptr =
                     ac_build_gep_ptr(&ctx->ac, gdsbase, LLVMConstInt(ctx->ac.i32, i, 0));

                  dwords_written[i] = LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpAdd,
                                                         gds_ptr, dwords_written[i],
                                                         LLVMAtomicOrderingMonotonic, false);
               }
            }

            /* TODO: This might not be needed if GDS executes instructions in order. */
            ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);

            /* Set release=1 to end a GDS mutex. Set done=1 because it's the last one. */
            args[6] = args[7] = ctx->ac.i1true;
            ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
                               args, ARRAY_SIZE(args), 0);

            tmp = dwords_written[0];
            for (unsigned i = 1; i < 4; i++) {
               if (ctx->shader->selector->info.base.xfb_stride[i]) {
                  dwords_written[i] = ac_build_readlane(&ctx->ac, dwords_written[i], ctx->ac.i32_0);
                  tmp = ac_build_writelane(&ctx->ac, tmp, dwords_written[i], LLVMConstInt(ctx->ac.i32, i, 0));
               }
            }
         } else {
            args[1] = tmp; /* value to add */
            args[5] = LLVMConstInt(ctx->ac.i32, 4 << 24, false); /* bits 24+: lane count */

            tmp = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
                                     args, ARRAY_SIZE(args), 0);
         }

         /* Keep offsets in a VGPR for quick retrieval via readlane by
          * the first wave for bounds checking, and also store in LDS
          * for retrieval by all waves later. */
         LLVMBuildStore(builder, tmp, offsets_vgpr);

         tmp2 = LLVMBuildAdd(builder, ac_get_thread_id(&ctx->ac), scratch_offset_basev, "");
         tmp2 = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tmp2);
         LLVMBuildStore(builder, tmp, tmp2);
      }
      ac_build_endif(&ctx->ac, 5210);

      /* Determine the max emit per buffer. This is done via the SALU, in part
       * because LLVM can't generate divide-by-multiply if we try to do this
       * via VALU with one lane per buffer.
       */
      LLVMValueRef max_emit[4] = {};
      for (unsigned buffer = 0; buffer < 4; ++buffer) {
         if (stream_for_buffer[buffer] == -1)
            continue;

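         /* Dword 2 of the buffer descriptor holds the buffer size (num_records)
          * in bytes; shift right by 2 to convert it to dwords.
          */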
         LLVMValueRef bufsize_dw = LLVMBuildLShr(
            builder, LLVMBuildExtractElement(builder, so_buffer[buffer], i32_2, ""), i32_2, "");

         tmp = LLVMBuildLoad2(builder, ctx->ac.i32, offsets_vgpr, "");
         LLVMValueRef offset_dw =
            ac_build_readlane(&ctx->ac, tmp, LLVMConstInt(ctx->ac.i32, buffer, false));

         tmp = LLVMBuildSub(builder, bufsize_dw, offset_dw, "");
         tmp = LLVMBuildUDiv(builder, tmp, prim_stride_dw[buffer], "");

         tmp2 = LLVMBuildICmp(builder, LLVMIntULT, bufsize_dw, offset_dw, "");
         max_emit[buffer] = LLVMBuildSelect(builder, tmp2, ctx->ac.i32_0, tmp, "");
      }

      /* Determine the number of emitted primitives per stream and fixup the
       * GDS counter if necessary.
       *
       * This is complicated by the fact that a single stream can emit to
       * multiple buffers (but luckily not vice versa).
       */
      LLVMValueRef emit_vgpr = ctx->ac.i32_0;

      for (unsigned stream = 0; stream < 4; ++stream) {
         if (!info->num_stream_output_components[stream])
            continue;

         tmp = LLVMBuildLoad2(builder, ctx->ac.i32, generated_by_stream_vgpr, "");
         LLVMValueRef generated =
            ac_build_readlane(&ctx->ac, tmp, LLVMConstInt(ctx->ac.i32, stream, false));

         LLVMValueRef emit = generated;
         for (unsigned buffer = 0; buffer < 4; ++buffer) {
            if (stream_for_buffer[buffer] == stream)
               emit = ac_build_umin(&ctx->ac, emit, max_emit[buffer]);
         }

         emit_vgpr =
            ac_build_writelane(&ctx->ac, emit_vgpr, emit, LLVMConstInt(ctx->ac.i32, stream, false));

         /* Fixup the offset using a plain GDS atomic if we overflowed. */
         tmp = LLVMBuildICmp(builder, LLVMIntULT, emit, generated, "");
         ac_build_ifcc(&ctx->ac, tmp, 5221); /* scalar branch */
         tmp = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, bufmask_for_stream[stream], false),
                             ac_get_thread_id(&ctx->ac), "");
         tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
         ac_build_ifcc(&ctx->ac, tmp, 5222);
         {
            tmp = LLVMBuildSub(builder, generated, emit, "");
            tmp = LLVMBuildMul(builder, tmp, prim_stride_dw_vgpr, "");

            if (ctx->screen->info.gfx_level >= GFX11) {
               /* Gfx11 GDS instructions only operate on the first active lane.
                * This is an unrolled waterfall loop. We only get here when we overflow,
                * so it doesn't have to be fast.
                */
               for (unsigned i = 0; i < 4; i++) {
                  if (bufmask_for_stream[stream] & BITFIELD_BIT(i)) {
                     LLVMValueRef index = LLVMConstInt(ctx->ac.i32, i, 0);

                     ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, index, ""), 0);
                     LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpSub,
                                        LLVMBuildGEP2(builder, ctx->ac.i32, gdsbase, &index, 1, ""),
                                        tmp, LLVMAtomicOrderingMonotonic, false);
                     ac_build_endif(&ctx->ac, 0);
                  }
               }
            } else {
               LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpSub,
                                  LLVMBuildGEP2(builder, ctx->ac.i32, gdsbase, &tid, 1, ""),
                                  tmp, LLVMAtomicOrderingMonotonic, false);
            }
         }
         ac_build_endif(&ctx->ac, 5222);
         ac_build_endif(&ctx->ac, 5221);
      }

      tmp = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), i32_4, "");
      ac_build_ifcc(&ctx->ac, tmp, 5225);
      {
         tmp = LLVMBuildAdd(builder, ac_get_thread_id(&ctx->ac), scratch_emit_basev, "");
         tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tmp);
         LLVMBuildStore(builder, emit_vgpr, tmp);
      }
      ac_build_endif(&ctx->ac, 5225);
   }
   ac_build_endif(&ctx->ac, 5200);

   /* Determine the workgroup-relative per-thread / primitive offset into
    * the streamout buffers. */
   struct ac_wg_scan primemit_scan[4] = {};

   if (isgs) {
      for (unsigned stream = 0; stream < 4; ++stream) {
         if (!info->num_stream_output_components[stream])
            continue;

         primemit_scan[stream].stage = ctx->stage;
         primemit_scan[stream].enable_exclusive = true;
         primemit_scan[stream].op = nir_op_iadd;
         primemit_scan[stream].src = nggso->prim_enable[stream];
         primemit_scan[stream].scratch = ac_build_gep0(
            &ctx->ac, ctx->gs_ngg_scratch, LLVMConstInt(ctx->ac.i32, 12 + 8 * stream, false));
         primemit_scan[stream].waveidx = get_wave_id_in_tg(ctx);
         primemit_scan[stream].numwaves = get_tgsize(ctx);
         if (ctx->stage == MESA_SHADER_GEOMETRY) {
            /* ngg_subgroup_size is only the input size. GS can always generate up to 256 vertices. */
            primemit_scan[stream].maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);
         } else {
            primemit_scan[stream].maxwaves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size,
                                                          ctx->ac.wave_size);
         }
         ac_build_wg_scan_top(&ctx->ac, &primemit_scan[stream]);
      }
   }

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   /* Fetch the per-buffer offsets and per-stream emit counts in all waves. */
   LLVMValueRef wgoffset_dw[4] = {};

   {
      LLVMValueRef scratch_vgpr;

      tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, ac_get_thread_id(&ctx->ac));
      scratch_vgpr = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");

      for (unsigned buffer = 0; buffer < 4; ++buffer) {
         if (stream_for_buffer[buffer] >= 0) {
            wgoffset_dw[buffer] =
               ac_build_readlane(&ctx->ac, scratch_vgpr,
                                 LLVMConstInt(ctx->ac.i32, scratch_offset_base + buffer, false));
         }
      }

      for (unsigned stream = 0; stream < 4; ++stream) {
         if (info->num_stream_output_components[stream]) {
            nggso->emit[stream] =
               ac_build_readlane(&ctx->ac, scratch_vgpr,
                                 LLVMConstInt(ctx->ac.i32, scratch_emit_base + stream, false));
         }
      }
   }

   /* Write out primitive data */
   for (unsigned stream = 0; stream < 4; ++stream) {
      if (!info->num_stream_output_components[stream])
         continue;

      if (isgs) {
         ac_build_wg_scan_bottom(&ctx->ac, &primemit_scan[stream]);
      } else {
         primemit_scan[stream].result_exclusive = tid;
      }

      tmp = LLVMBuildICmp(builder, LLVMIntULT, primemit_scan[stream].result_exclusive,
                          nggso->emit[stream], "");
      tmp = LLVMBuildAnd(builder, tmp, nggso->prim_enable[stream], "");
      ac_build_ifcc(&ctx->ac, tmp, 5240);
      {
         LLVMValueRef offset_vtx =
            LLVMBuildMul(builder, primemit_scan[stream].result_exclusive, nggso->num_vertices, "");

         for (unsigned i = 0; i < max_num_vertices; ++i) {
            tmp = LLVMBuildICmp(builder, LLVMIntULT, LLVMConstInt(ctx->ac.i32, i, false),
                                nggso->num_vertices, "");
            ac_build_ifcc(&ctx->ac, tmp, 5241);
            build_streamout_vertex(ctx, so_buffer, wgoffset_dw, stream, offset_vtx,
                                   nggso->vertices[i]);
            ac_build_endif(&ctx->ac, 5241);
            offset_vtx = LLVMBuildAdd(builder, offset_vtx, ctx->ac.i32_1, "");
         }
      }
      ac_build_endif(&ctx->ac, 5240);
   }
}

/* LDS layout of ES vertex data for NGG culling. */
enum
{
   /* Byte 0: Boolean ES thread accepted (unculled) flag.
    * Byte 1: New ES thread ID, loaded by GS to prepare the prim export value.
    * Byte 2: TES rel patch ID
    * Byte 3: 8-bit clip distance mask: 1 means the clip distance is negative.
    *         The mask from all vertices is AND'ed. If the result is non-zero,
    *         the primitive is culled.
    */
   lds_byte0_accept_flag = 0,
   lds_byte1_new_thread_id,
   lds_byte2_tes_rel_patch_id,
   lds_byte3_clipdist_neg_mask,

   lds_packed_data = 0, /* lds_byteN_... */
   lds_pos_cull_x_div_w,
   lds_pos_cull_y_div_w,
   lds_pos_cull_w,

   lds_pos_x = lds_packed_data + 1,
   lds_pos_y,
   lds_pos_z,
   lds_pos_w,
   /* If VS: */
   lds_vertex_id,
   lds_instance_id, /* optional */
   /* If TES: */
   lds_tes_u = lds_vertex_id,
   lds_tes_v = lds_instance_id,
   lds_tes_patch_id, /* optional */
};

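/* Reinterpret an LDS dword pointer as an i8 pointer and index it by byte, e.g.
 * to address the individual lds_byteN_* fields of lds_packed_data above.
 */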
static LLVMValueRef si_build_gep_i8_var(struct si_shader_context *ctx, LLVMValueRef ptr,
                                        LLVMValueRef index)
{
   LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);

   return LLVMBuildGEP2(ctx->ac.builder, ctx->ac.i8,
                        LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""), &index, 1, "");
}

static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx, LLVMValueRef ptr,
                                    unsigned byte_index)
{
   assert(byte_index < 4);
   return si_build_gep_i8_var(ctx, ptr, LLVMConstInt(ctx->ac.i32, byte_index, 0));
}

static unsigned ngg_nogs_vertex_size(struct si_shader *shader)
{
   unsigned lds_vertex_size = 0;

   /* The edgeflag is always stored in the last element that's also
    * used for padding to reduce LDS bank conflicts. */
   if (si_shader_uses_streamout(shader))
      lds_vertex_size = 4 * shader->selector->info.num_outputs + 1;
   if (gfx10_ngg_writes_user_edgeflags(shader))
      lds_vertex_size = MAX2(lds_vertex_size, 1);

   /* LDS size for passing data from GS to ES.
    * GS stores Primitive IDs into LDS at the address corresponding
    * to the ES thread of the provoking vertex. All ES threads
    * load and export PrimitiveID for their thread.
    */
   if (shader->selector->stage == MESA_SHADER_VERTEX && shader->key.ge.mono.u.vs_export_prim_id)
      lds_vertex_size = MAX2(lds_vertex_size, 1);

   if (shader->key.ge.opt.ngg_culling) {
      if (shader->selector->stage == MESA_SHADER_VERTEX) {
         STATIC_ASSERT(lds_instance_id + 1 == 7);
         lds_vertex_size = MAX2(lds_vertex_size, 7);
      } else {
         assert(shader->selector->stage == MESA_SHADER_TESS_EVAL);

         if (shader->selector->info.uses_primid || shader->key.ge.mono.u.vs_export_prim_id) {
            STATIC_ASSERT(lds_tes_patch_id + 2 == 9); /* +1 for LDS padding */
            lds_vertex_size = MAX2(lds_vertex_size, 9);
         } else {
            STATIC_ASSERT(lds_tes_v + 1 == 7);
            lds_vertex_size = MAX2(lds_vertex_size, 7);
         }
      }
   }

   return lds_vertex_size;
}

/**
 * Returns an `[N x i32] addrspace(LDS)*` pointing at contiguous LDS storage
 * for the vertex outputs.
 */
static LLVMValueRef ngg_nogs_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef vtxid)
{
   /* The extra dword is used to avoid LDS bank conflicts. */
   unsigned vertex_size = ngg_nogs_vertex_size(ctx->shader);
   LLVMTypeRef ai32 = LLVMArrayType(ctx->ac.i32, vertex_size);
   LLVMTypeRef pai32 = LLVMPointerType(ai32, AC_ADDR_SPACE_LDS);
   LLVMValueRef tmp = LLVMBuildBitCast(ctx->ac.builder, ctx->esgs_ring, pai32, "");
   return LLVMBuildGEP2(ctx->ac.builder, ai32, tmp, &vtxid, 1, "");
}

static LLVMValueRef si_insert_input_v4i32(struct si_shader_context *ctx, LLVMValueRef ret,
                                          struct ac_arg param, unsigned return_index)
{
   LLVMValueRef v = ac_get_arg(&ctx->ac, param);

   for (unsigned i = 0; i < 4; i++) {
      ret = LLVMBuildInsertValue(ctx->ac.builder, ret, ac_llvm_extract_elem(&ctx->ac, v, i),
                                 return_index + i, "");
   }
   return ret;
}

static void load_vertex_counts(struct si_shader_context *ctx, LLVMValueRef lds,
                               unsigned max_waves, LLVMValueRef tid,
                               LLVMValueRef *total_count,
                               LLVMValueRef *prefix_sum)
{
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef i8vec4_lane = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
   unsigned num_i8vec4 = DIV_ROUND_UP(max_waves, 4);

   /* If all threads loaded the vertex counts, it would cause many LDS bank conflicts
    * and the performance could decrease up to WaveSize times (32x or 64x).
    *
    * Therefore, only load the i-th tuple of vertex counts in the i-th thread. Other threads will
    * get them through readlane. 4 8-bit vertex counts are loaded per thread.
    */
   ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntULT, tid,
                                         LLVMConstInt(ctx->ac.i32, num_i8vec4, 0), ""), 17771);
   LLVMBuildStore(builder, LLVMBuildLoad2(builder, ctx->ac.i32, ac_build_gep0(&ctx->ac, lds, tid), ""), i8vec4_lane);
   ac_build_endif(&ctx->ac, 17771);

   /* Compute the number of ES waves. */
   LLVMValueRef num_waves = get_tgsize(ctx);

   /* Compute a byte mask where each byte is either 0 or 0xff depending on whether the wave
    * exists. We need the mask to clear uninitialized bytes in LDS and to compute the prefix sum.
    *
    * 8 waves: valid_mask = ~0ull >> (64 - num_waves * 8)
    * 4 waves: valid_mask = ~0 >> (32 - num_waves * 8)
    */
   LLVMValueRef num_waves8 = LLVMBuildShl(builder, num_waves, LLVMConstInt(ctx->ac.i32, 3, 0), "");
   LLVMValueRef valid_mask;

   if (max_waves > 4) {
      LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 64, 0),
                                                 num_waves8, "");
      valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i64, ~0ull, 0),
                                 LLVMBuildZExt(builder, num_waves8_rev, ctx->ac.i64, ""), "");
   } else {
      LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 32, 0),
                                                 num_waves8, "");
      valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, ~0, 0), num_waves8_rev, "");
   }

   /* Compute a byte mask where bytes below wave_id are 0xff, else they are 0.
    *
    * prefix_mask = ~(~0 << (wave_id * 8))
    */
   LLVMTypeRef type = max_waves > 4 ? ctx->ac.i64 : ctx->ac.i32;
   LLVMValueRef wave_id8 = LLVMBuildShl(builder, get_wave_id_in_tg(ctx),
                                        LLVMConstInt(ctx->ac.i32, 3, 0), "");
   LLVMValueRef prefix_mask =
      LLVMBuildNot(builder, LLVMBuildShl(builder, LLVMConstInt(type, ~0ull, 0),
                                         LLVMBuildZExt(builder, wave_id8, type, ""), ""), "");

   /* Compute the total vertex count and the vertex count of previous waves (prefix). */
   *total_count = ctx->ac.i32_0;
   *prefix_sum = ctx->ac.i32_0;

   for (unsigned i = 0; i < num_i8vec4; i++) {
      LLVMValueRef i8vec4;

      i8vec4 = ac_build_readlane_no_opt_barrier(&ctx->ac, LLVMBuildLoad2(builder, ctx->ac.i32, i8vec4_lane, ""),
                                                LLVMConstInt(ctx->ac.i32, i, 0));
      /* Inactive waves have uninitialized vertex counts. Set them to 0 using this. */
      i8vec4 = LLVMBuildAnd(builder, i8vec4,
                            ac_unpack_param(&ctx->ac, valid_mask, 32 * i, 32), "");
      /* Compute the sum of all i8vec4 components and add it to the result. */
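      /* v_sad_u8 with a zero reference computes |b0| + |b1| + |b2| + |b3| plus
       * the accumulator operand, i.e. the sum of the four byte-sized counts.
       */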
      *total_count = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
                                        (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *total_count},
                                        3, AC_FUNC_ATTR_READNONE);
      ac_set_range_metadata(&ctx->ac, *total_count, 0, 64*4 + 1); /* the result is at most 64*4 */

      /* Compute the sum of the vertex counts of all previous waves. */
      i8vec4 = LLVMBuildAnd(builder, i8vec4,
                            ac_unpack_param(&ctx->ac, prefix_mask, 32 * i, 32), "");
      *prefix_sum = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
                                       (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *prefix_sum},
                                       3, AC_FUNC_ATTR_READNONE);
      ac_set_range_metadata(&ctx->ac, *prefix_sum, 0, 64*4 + 1); /* the result is at most 64*4 */
   }
   *total_count = ac_build_readlane_no_opt_barrier(&ctx->ac, *total_count, NULL);
}

/**
 * Given a total thread count, update total and per-wave thread counts in input SGPRs
 * and return the per-wave thread count.
 *
 * \param new_num_threads    Total thread count on the input, per-wave thread count on the output.
 * \param tg_info            tg_info SGPR value
 * \param tg_info_num_bits   the bit size of thread count field in tg_info
 * \param tg_info_shift      the bit offset of the thread count field in tg_info
 * \param wave_info          merged_wave_info SGPR value
 * \param wave_info_num_bits the bit size of thread count field in merged_wave_info
 * \param wave_info_shift    the bit offset of the thread count field in merged_wave_info
 */
static void update_thread_counts(struct si_shader_context *ctx, LLVMValueRef *new_num_threads,
                                 LLVMValueRef *tg_info, unsigned tg_info_num_bits,
                                 unsigned tg_info_shift, LLVMValueRef *wave_info,
                                 unsigned wave_info_num_bits, unsigned wave_info_shift)
{
   LLVMBuilderRef builder = ctx->ac.builder;

   /* Update the total thread count. */
   unsigned tg_info_mask = ~(u_bit_consecutive(0, tg_info_num_bits) << tg_info_shift);
   *tg_info = LLVMBuildAnd(builder, *tg_info, LLVMConstInt(ctx->ac.i32, tg_info_mask, 0), "");
   *tg_info = LLVMBuildOr(
      builder, *tg_info,
      LLVMBuildShl(builder, *new_num_threads, LLVMConstInt(ctx->ac.i32, tg_info_shift, 0), ""), "");

   /* Update the per-wave thread count. */
   LLVMValueRef prev_threads = LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
                                            LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0), "");
   *new_num_threads = LLVMBuildSub(builder, *new_num_threads, prev_threads, "");
   *new_num_threads = ac_build_imax(&ctx->ac, *new_num_threads, ctx->ac.i32_0);
   *new_num_threads =
      ac_build_imin(&ctx->ac, *new_num_threads, LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0));
   unsigned wave_info_mask = ~(u_bit_consecutive(0, wave_info_num_bits) << wave_info_shift);
   *wave_info = LLVMBuildAnd(builder, *wave_info, LLVMConstInt(ctx->ac.i32, wave_info_mask, 0), "");
   *wave_info = LLVMBuildOr(
      builder, *wave_info,
      LLVMBuildShl(builder, *new_num_threads, LLVMConstInt(ctx->ac.i32, wave_info_shift, 0), ""),
      "");
}

static void gfx10_build_primitive_accepted(struct ac_llvm_context *ac, LLVMValueRef accepted,
                                           void *userdata)
{
   struct si_shader_context *ctx = container_of(ac, struct si_shader_context, ac);
   LLVMValueRef *params = (LLVMValueRef *)userdata;
   LLVMValueRef gs_accepted = params[0];
   LLVMValueRef *gs_vtxptr = (LLVMValueRef *)params[1];

   unsigned num_vertices;
   ngg_get_vertices_per_prim(ctx, &num_vertices);

   ac_build_ifcc(&ctx->ac, accepted, 0);
   LLVMBuildStore(ctx->ac.builder, ctx->ac.i32_1, gs_accepted);

   if (gs_vtxptr) {
      for (unsigned vtx = 0; vtx < num_vertices; vtx++) {
         LLVMBuildStore(ctx->ac.builder, ctx->ac.i8_1,
                        si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte0_accept_flag));
      }
   }
   ac_build_endif(&ctx->ac, 0);
}

static void add_clipdist_bit(struct si_shader_context *ctx, LLVMValueRef distance, unsigned i,
                             LLVMValueRef *packed_data)
{
   LLVMValueRef neg = LLVMBuildFCmp(ctx->ac.builder, LLVMRealOLT, distance, ctx->ac.f32_0, "");
   neg = LLVMBuildZExt(ctx->ac.builder, neg, ctx->ac.i32, "");
   /* Put the negative distance flag into lds_byte3_clipdist_neg_mask. */
   neg = LLVMBuildShl(ctx->ac.builder, neg, LLVMConstInt(ctx->ac.i32, 24 + i, 0), "");
   *packed_data = LLVMBuildOr(ctx->ac.builder, *packed_data, neg, "");
}

static bool add_clipdist_bits_for_clipvertex(struct si_shader_context *ctx,
                                             unsigned clipdist_enable,
                                             LLVMValueRef clipvertex[4],
                                             LLVMValueRef *packed_data)
{
   struct ac_export_args clipdist[2];
   bool added = false;

   si_llvm_clipvertex_to_clipdist(ctx, clipdist, clipvertex);

   for (unsigned j = 0; j < 8; j++) {
      if (!(clipdist_enable & BITFIELD_BIT(j)))
         continue;

      LLVMValueRef distance = clipdist[j / 4].out[j % 4];
      add_clipdist_bit(ctx, distance, j, packed_data);
      added = true;
   }
   return added;
}

static void cull_primitive(struct si_shader_context *ctx,
                           LLVMValueRef pos[3][4], LLVMValueRef clipdist_accepted,
                           LLVMValueRef out_prim_accepted, LLVMValueRef gs_vtxptr_accept[3])
{
   struct si_shader *shader = ctx->shader;
   LLVMBuilderRef builder = ctx->ac.builder;

   LLVMValueRef vp_scale[2] = {}, vp_translate[2] = {}, small_prim_precision = NULL;
   LLVMValueRef clip_half_line_width[2] = {};

   /* Load the viewport state for small prim culling. */
   bool prim_is_lines = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_LINES;
   LLVMValueRef ptr = ac_get_arg(&ctx->ac, ctx->small_prim_cull_info);
   /* Lines will always use the non-AA viewport transformation. */
   LLVMValueRef vp = ac_build_load_to_sgpr(&ctx->ac, ptr,
                                           prim_is_lines ? ctx->ac.i32_1 : ctx->ac.i32_0);
   vp = LLVMBuildBitCast(builder, vp, ctx->ac.v4f32, "");
   vp_scale[0] = ac_llvm_extract_elem(&ctx->ac, vp, 0);
   vp_scale[1] = ac_llvm_extract_elem(&ctx->ac, vp, 1);
   vp_translate[0] = ac_llvm_extract_elem(&ctx->ac, vp, 2);
   vp_translate[1] = ac_llvm_extract_elem(&ctx->ac, vp, 3);

   /* Execute culling code. */
   struct ac_cull_options options = {};
   options.cull_view_xy = true;
   options.cull_w = true;

   if (prim_is_lines) {
      ptr = LLVMBuildPointerCast(ctx->ac.builder, ptr,
                                 LLVMPointerType(ctx->ac.v2f32, AC_ADDR_SPACE_CONST_32BIT), "");
      LLVMValueRef terms = ac_build_load_to_sgpr(&ctx->ac, ptr, LLVMConstInt(ctx->ac.i32, 4, 0));
      terms = LLVMBuildBitCast(builder, terms, ctx->ac.v2f32, "");
      clip_half_line_width[0] = ac_llvm_extract_elem(&ctx->ac, terms, 0);
      clip_half_line_width[1] = ac_llvm_extract_elem(&ctx->ac, terms, 1);
      small_prim_precision = GET_FIELD(ctx, GS_STATE_SMALL_PRIM_PRECISION_NO_AA);

      options.num_vertices = 2;
      options.cull_small_prims = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_SMALL_LINES_DIAMOND_EXIT;

      assert(!(shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE));
      assert(!(shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE));
   } else {
      /* Get the small prim filter precision. */
      small_prim_precision = GET_FIELD(ctx, GS_STATE_SMALL_PRIM_PRECISION);

      options.num_vertices = 3;
      options.cull_front = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE;
      options.cull_back = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE;
      options.cull_small_prims = true; /* this would only be false with conservative rasterization */
      options.cull_zero_area = options.cull_front || options.cull_back;
   }

   /* Extract the small prim precision. */
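   /* The field presumably holds the low 4 bits of an f32 exponent: OR'ing in 0x70
    * and shifting into the exponent bits (23+) reconstructs a power-of-two float
    * in the range [2^-15, 2^0].
    */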
   small_prim_precision =
      LLVMBuildOr(builder, small_prim_precision, LLVMConstInt(ctx->ac.i32, 0x70, 0), "");
   small_prim_precision =
      LLVMBuildShl(builder, small_prim_precision, LLVMConstInt(ctx->ac.i32, 23, 0), "");
   small_prim_precision = LLVMBuildBitCast(builder, small_prim_precision, ctx->ac.f32, "");

   /* Tell ES threads whether their vertex survived. */
   LLVMValueRef params[] = {
      out_prim_accepted,
      (void*)gs_vtxptr_accept,
   };
   ac_cull_primitive(&ctx->ac, pos, clipdist_accepted, vp_scale, vp_translate,
                     small_prim_precision, clip_half_line_width,
                     &options, gfx10_build_primitive_accepted, params);
}

/**
 * Cull primitives for NGG VS or TES, then compact vertices, which happens
 * before the VS or TES main function. Return values for the main function.
 * Also return the position, which is passed to the shader as an input,
 * so that we don't compute it twice.
 */
void gfx10_ngg_culling_build_end(struct si_shader_context *ctx)
{
   struct si_shader *shader = ctx->shader;
   struct si_shader_selector *sel = shader->selector;
   struct si_shader_info *info = &sel->info;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef *addrs = ctx->abi.outputs;
   unsigned max_waves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size, ctx->ac.wave_size);

   assert(shader->key.ge.opt.ngg_culling);
   assert(shader->key.ge.as_ngg);
   assert(sel->stage == MESA_SHADER_VERTEX ||
          (sel->stage == MESA_SHADER_TESS_EVAL && !shader->key.ge.as_es));

   LLVMValueRef es_vtxptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
   LLVMValueRef packed_data = ctx->ac.i32_0;
   LLVMValueRef position[4] = {};
   unsigned pos_index = 0;
   unsigned clip_plane_enable = SI_NGG_CULL_GET_CLIP_PLANE_ENABLE(shader->key.ge.opt.ngg_culling);
   unsigned clipdist_enable = (sel->info.clipdist_mask & clip_plane_enable) | sel->info.culldist_mask;
   bool has_clipdist_mask = false;

   for (unsigned i = 0; i < info->num_outputs; i++) {
      LLVMValueRef clipvertex[4];
      unsigned base;

      switch (info->output_semantic[i]) {
      case VARYING_SLOT_POS:
         /* If we are going to cull everything (rasterizer_discard), discard
          * the position. This is useful for analyzing maximum theoretical
          * performance without VS input loads.
          */
         if (shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE &&
             shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE) {
            for (unsigned j = 0; j < 4; j++)
               LLVMBuildStore(builder, LLVMGetUndef(ctx->ac.f32), addrs[4 * i + j]);
            break;
         }

         pos_index = i;
         for (unsigned j = 0; j < 4; j++) {
            position[j] = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");
         }

         /* Store Position.W into LDS. */
         LLVMBuildStore(
            builder, ac_to_integer(&ctx->ac, position[3]),
            ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_pos_cull_w, 0)));

         /* Store Position.XY / W into LDS. */
         for (unsigned chan = 0; chan < 2; chan++) {
            LLVMValueRef val = ac_build_fdiv(&ctx->ac, position[chan], position[3]);
            LLVMBuildStore(
               builder, ac_to_integer(&ctx->ac, val),
               ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_pos_cull_x_div_w + chan, 0)));
         }
         break;

      case VARYING_SLOT_CLIP_DIST0:
      case VARYING_SLOT_CLIP_DIST1:
         base = info->output_semantic[i] == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;

         for (unsigned j = 0; j < 4; j++) {
            unsigned index = base + j;

            if (!(clipdist_enable & BITFIELD_BIT(index)))
               continue;

            LLVMValueRef distance = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");
            add_clipdist_bit(ctx, distance, index, &packed_data);
            has_clipdist_mask = true;
         }
         break;

      case VARYING_SLOT_CLIP_VERTEX:
         for (unsigned j = 0; j < 4; j++)
            clipvertex[j] = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");

         if (add_clipdist_bits_for_clipvertex(ctx, clipdist_enable, clipvertex, &packed_data))
            has_clipdist_mask = true;
         break;
      }
   }

   if (clip_plane_enable && !sel->info.clipdist_mask) {
      /* When clip planes are enabled and there are no clip distance outputs,
       * we should use user clip planes and cull against the position.
       */
      assert(!has_clipdist_mask);
      if (add_clipdist_bits_for_clipvertex(ctx, clipdist_enable, position, &packed_data))
         has_clipdist_mask = true;
   }

   /* Initialize the packed data. */
   LLVMBuildStore(
      builder, packed_data,
      ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_packed_data, 0)));
   ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   LLVMValueRef tid = ac_get_thread_id(&ctx->ac);

   unsigned num_vertices;
   ngg_get_vertices_per_prim(ctx, &num_vertices);

   /* The hardware requires that there are no holes between unculled vertices,
    * which means we have to pack ES threads, i.e. reduce the ES thread count
    * and move ES input VGPRs to lower threads. The upside is that varyings
    * are only fetched and computed for unculled vertices.
    *
    * Vertex compaction:
    *
    * Part 1: Store the surviving vertex count for each wave in LDS.
    *   - The GS culling code notifies ES threads which vertices were accepted.
    *   - Barrier
    *   - ES threads will compute the vertex count and store it in LDS.
    * - Barrier
    * - Each wave loads the vertex counts from LDS.
    *
    * Part 2: Compact ES threads:
    * - Compute the prefix sum for each surviving vertex. This is the new thread ID
    *   of the vertex.
    * - Write input VGPRs and vertex positions for each surviving vertex into the LDS
    *   address of the new thread ID.
    * - Now kill all waves that have inactive threads.
    * - Barrier
    * - Update vertex indices and null flag in the GS input VGPRs.
    *
    * Part 3: Update input GPRs.
    * - For all waves, update per-wave thread counts in input SGPRs.
    * - In ES threads, update the ES input VGPRs (VertexID, InstanceID, TES inputs).
    */

   LLVMValueRef vtxindex[3];
   for (unsigned i = 0; i < num_vertices; ++i)
      vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[i / 2], (i & 1) * 16, 16);

   LLVMValueRef gs_vtxptr[3];
   for (unsigned i = 0; i < num_vertices; i++)
      gs_vtxptr[i] = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);

   es_vtxptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));

   /* Adding these optimization barriers improves the generated code as follows. Crazy right?
    *
    * - s_mov_b32 s4, 0xffff
    * - v_lshrrev_b32_e32 v10, 16, v0
    * - v_and_b32_e32 v12, s4, v0
    * - v_and_b32_e32 v11, s4, v1
    *   s_bfe_u32 s4, s3, 0x80008
    * - s_mov_b64 s[8:9], 0
    * - v_mul_u32_u24_e32 v0, 28, v10
    * - v_mul_u32_u24_e32 v9, 28, v12
    * - v_mul_u32_u24_e32 v1, 28, v11
    * + v_mov_b32_e32 v11, 28
    *   v_cmp_gt_u32_e32 vcc, s4, v2
    * + s_mov_b64 s[8:9], 0
    *   s_waitcnt lgkmcnt(0)
    *   s_barrier
    * + v_mul_u32_u24_sdwa v10, v0, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_0 src1_sel:DWORD
    * + v_mul_u32_u24_sdwa v23, v0, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
    * + v_mul_u32_u24_sdwa v0, v1, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_0 src1_sel:DWORD
    *   s_and_saveexec_b64 s[44:45], vcc
    *   s_cbranch_execz BB2_8
    * - v_mul_u32_u24_e32 v16, 28, v12
    * - v_mul_u32_u24_e32 v17, 28, v11
    * - v_mul_u32_u24_e32 v18, 28, v10
    */
   for (unsigned i = 0; i < num_vertices; i++)
      ac_build_optimization_barrier(&ctx->ac, &gs_vtxptr[i], false);

   LLVMValueRef gs_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i32, "");

   /* Do culling in GS threads. */
   ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 16002);
   {
      /* Load positions. */
      LLVMValueRef pos[3][4] = {};
      LLVMValueRef clipdist_neg_mask = NULL;

      for (unsigned vtx = 0; vtx < num_vertices; vtx++) {
         for (unsigned chan = 0; chan < 4; chan++) {
            unsigned index;
            if (chan == 0 || chan == 1)
               index = lds_pos_cull_x_div_w + chan;
            else if (chan == 3)
               index = lds_pos_cull_w;
            else
               continue;

            LLVMValueRef addr =
               ac_build_gep0(&ctx->ac, gs_vtxptr[vtx], LLVMConstInt(ctx->ac.i32, index, 0));
            pos[vtx][chan] = LLVMBuildLoad2(builder, ctx->ac.i32, addr, "");
1205            pos[vtx][chan] = ac_to_float(&ctx->ac, pos[vtx][chan]);
1206         }
1207
1208         if (has_clipdist_mask) {
1209            /* Load and AND clip distance masks. Each bit means whether that clip distance is
1210             * negative. If all masks are AND'ed and the result is 0, the primitive isn't culled
1211             * by clip distances.
1212             */
            LLVMValueRef addr = si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte3_clipdist_neg_mask);
            LLVMValueRef mask = LLVMBuildLoad2(builder, ctx->ac.i8, addr, "");
            if (!clipdist_neg_mask)
               clipdist_neg_mask = mask;
            else
               clipdist_neg_mask = LLVMBuildAnd(builder, clipdist_neg_mask, mask, "");
         }
      }

      LLVMValueRef clipdist_accepted =
         has_clipdist_mask ? LLVMBuildICmp(builder, LLVMIntEQ, clipdist_neg_mask, ctx->ac.i8_0, "")
                           : ctx->ac.i1true;

      cull_primitive(ctx, pos, clipdist_accepted, gs_accepted, gs_vtxptr);
   }
   ac_build_endif(&ctx->ac, 16002);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   gs_accepted = LLVMBuildLoad2(builder, ctx->ac.i32, gs_accepted, "");

   LLVMValueRef vertex_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, "");
   LLVMValueRef vertex_mask = ac_build_alloca(&ctx->ac, ctx->ac.iN_wavemask, "");

   /* Convert the per-vertex accept flag to a vertex thread mask, store it in registers. */
   ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007);
   {
      LLVMValueRef accepted =
         LLVMBuildLoad2(builder, ctx->ac.i8, si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), "");
      accepted = LLVMBuildICmp(builder, LLVMIntNE, accepted, ctx->ac.i8_0, "");
      LLVMValueRef mask = ac_get_i1_sgpr_mask(&ctx->ac, accepted);

      LLVMBuildStore(builder, accepted, vertex_accepted);
      LLVMBuildStore(builder, mask, vertex_mask);
   }
   ac_build_endif(&ctx->ac, 16007);

   /* Store the per-wave vertex count to LDS. Non-ES waves store 0. */
   vertex_mask = LLVMBuildLoad2(builder, ctx->ac.iN_wavemask, vertex_mask, "");
   ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008);
   {
      LLVMValueRef vertex_count = ac_build_bit_count(&ctx->ac, vertex_mask);
      LLVMBuildStore(builder, LLVMBuildTrunc(builder, vertex_count, ctx->ac.i8, ""),
                     si_build_gep_i8_var(ctx, ctx->gs_ngg_scratch, get_wave_id_in_tg(ctx)));
   }
   ac_build_endif(&ctx->ac, 16008);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   /* Load the vertex masks and compute the new ES thread count. */
   LLVMValueRef new_num_es_threads, prefix_sum, kill_wave;
   load_vertex_counts(ctx, ctx->gs_ngg_scratch, max_waves, tid, &new_num_es_threads,
                      &prefix_sum);

   bool uses_instance_id = ctx->stage == MESA_SHADER_VERTEX &&
                           (sel->info.uses_instanceid ||
                            shader->key.ge.part.vs.prolog.instance_divisor_is_one ||
                            shader->key.ge.part.vs.prolog.instance_divisor_is_fetched);
   bool uses_tes_prim_id = ctx->stage == MESA_SHADER_TESS_EVAL &&
                           (sel->info.uses_primid || shader->key.ge.mono.u.vs_export_prim_id);

   /* ES threads compute their prefix sum, which is the new ES thread ID.
    * Then they write the vertex position and input VGPRs into the LDS address
    * of the new thread ID. The compacted threads will later load their input
    * VGPRs from there.
    */
   vertex_accepted = LLVMBuildLoad2(builder, ctx->ac.i1, vertex_accepted, "");
   ac_build_ifcc(&ctx->ac, vertex_accepted, 16009);
   {
      /* Add the number of bits set in vertex_mask up to the current thread ID - 1
       * to get the prefix sum.
       */
      prefix_sum = LLVMBuildAdd(builder, prefix_sum, ac_build_mbcnt(&ctx->ac, vertex_mask), "");

      LLVMValueRef new_id = prefix_sum;
      LLVMValueRef new_vtx = ngg_nogs_vertex_ptr(ctx, new_id);

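      /* Store the new thread ID, so that GS threads can later translate the
       * original vertex indices into compacted ones when building the
       * primitive export (see the loads of lds_byte1_new_thread_id below).
       */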
      LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->ac.i8, ""),
                     si_build_gep_i8(ctx, es_vtxptr, lds_byte1_new_thread_id));

      /* Store Position.XYZW into LDS. */
      for (unsigned chan = 0; chan < 4; chan++) {
         LLVMBuildStore(
            builder, ac_to_integer(&ctx->ac,
                                   LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * pos_index + chan], "")),
            ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_pos_x + chan, 0)));
      }

      /* Store VertexID and InstanceID into LDS. ES threads will have to load them
       * from LDS after vertex compaction and use them instead of their own
       * system values.
       */
      if (ctx->stage == MESA_SHADER_VERTEX) {
         LLVMBuildStore(
            builder, ctx->abi.vertex_id,
            ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_vertex_id, 0)));
         if (uses_instance_id) {
            LLVMBuildStore(
               builder, ctx->abi.instance_id,
               ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_instance_id, 0)));
         }
      } else {
         assert(ctx->stage == MESA_SHADER_TESS_EVAL);
         LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args.tes_u)),
                        ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_u, 0)));
         LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args.tes_v)),
                        ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_v, 0)));
         LLVMBuildStore(builder, LLVMBuildTrunc(builder, ac_get_arg(&ctx->ac, ctx->args.tes_rel_patch_id), ctx->ac.i8, ""),
                        si_build_gep_i8(ctx, new_vtx, lds_byte2_tes_rel_patch_id));
         if (uses_tes_prim_id) {
            LLVMBuildStore(
               builder, ac_get_arg(&ctx->ac, ctx->args.tes_patch_id),
               ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_patch_id, 0)));
         }
      }
   }
   ac_build_endif(&ctx->ac, 16009);

   /* If all vertices are culled, set the primitive count to 0, so that all waves are killed below. */
   LLVMValueRef num_primitives = ngg_get_prim_cnt(ctx);
   num_primitives = LLVMBuildSelect(builder,
                                    LLVMBuildICmp(builder, LLVMIntEQ, new_num_es_threads,
                                                  ctx->ac.i32_0, ""),
                                    ctx->ac.i32_0, num_primitives, "");
   /* Kill waves whose threads are all inactive, i.e. waves that start past
    * both the new ES thread count and the primitive count.
    */
   kill_wave = LLVMBuildICmp(builder, LLVMIntULE,
                             ac_build_imax(&ctx->ac, new_num_es_threads, num_primitives),
                             LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
                                          LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0), ""),
                             "");
   ac_build_ifcc(&ctx->ac, kill_wave, 19202);
   {
      /* If we are killing wave 0, send the message that there are no
       * primitives in this threadgroup.
       */
      ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), ctx->ac.i32_0, ctx->ac.i32_0);
      ac_build_s_endpgm(&ctx->ac);
   }
   ac_build_endif(&ctx->ac, 19202);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   /* Send the final vertex and primitive counts. */
   ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), new_num_es_threads,
                                 ngg_get_prim_cnt(ctx));

   /* Update thread counts in SGPRs. */
   LLVMValueRef new_gs_tg_info = ac_get_arg(&ctx->ac, ctx->args.gs_tg_info);
   LLVMValueRef new_merged_wave_info = ac_get_arg(&ctx->ac, ctx->args.merged_wave_info);

   /* This also converts the thread count from the total count to the per-wave count. */
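   /* The integer arguments are assumed to be the bit width and offset of the
    * vertex count within gs_tg_info, and of the per-wave thread counts within
    * merged_wave_info.
    */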
   update_thread_counts(ctx, &new_num_es_threads, &new_gs_tg_info, 9, 12, &new_merged_wave_info, 8,
                        0);

   /* Update vertex indices in VGPR0 (same format as NGG passthrough).
    *
    * Set the null flag at the beginning (culled), and then
    * overwrite it for accepted primitives.
    */
   LLVMValueRef new_vgpr0 =
      ac_build_alloca_init(&ctx->ac, LLVMConstInt(ctx->ac.i32, 1u << 31, 0), "");

   /* Get vertex indices after vertex compaction. */
   ac_build_ifcc(&ctx->ac, LLVMBuildTrunc(builder, gs_accepted, ctx->ac.i1, ""), 16011);
   {
      struct ac_ngg_prim prim = {};
      prim.num_vertices = num_vertices;
      prim.isnull = ctx->ac.i1false;

      if (gfx10_edgeflags_have_effect(shader))
         prim.edgeflags = ac_pack_edgeflags_for_export(&ctx->ac, &ctx->args);
      else
         prim.edgeflags = ctx->ac.i32_0;

      for (unsigned vtx = 0; vtx < num_vertices; vtx++) {
         prim.index[vtx] = LLVMBuildLoad2(
            builder, ctx->ac.i8, si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte1_new_thread_id), "");
         prim.index[vtx] = LLVMBuildZExt(builder, prim.index[vtx], ctx->ac.i32, "");
      }

      /* Set the new GS input VGPR. */
      LLVMBuildStore(builder, ac_pack_prim_export(&ctx->ac, &prim), new_vgpr0);
   }
   ac_build_endif(&ctx->ac, 16011);

   if (gfx10_ngg_export_prim_early(shader))
      gfx10_ngg_build_export_prim(ctx, NULL, LLVMBuildLoad2(builder, ctx->ac.i32, new_vgpr0, ""));

   /* Prepare LDS addresses of the new ES input VGPRs. */
   LLVMValueRef input_vgpr_addresses[4] = {
      ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_vertex_id, 0)),
      ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_instance_id, 0)),
   };
   if (ctx->stage == MESA_SHADER_TESS_EVAL) {
      input_vgpr_addresses[2] = si_build_gep_i8(ctx, es_vtxptr, lds_byte2_tes_rel_patch_id);
      if (uses_tes_prim_id) {
         input_vgpr_addresses[3] = ac_build_gep0(&ctx->ac, es_vtxptr,
                                                 LLVMConstInt(ctx->ac.i32, lds_tes_patch_id, 0));
      }
   }

   /* Return values for the main function. */
   LLVMValueRef ret = ctx->return_value;
   LLVMValueRef val;

   ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_gs_tg_info, 2, "");
   ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_merged_wave_info, 3, "");
   if (ctx->stage == MESA_SHADER_TESS_EVAL)
      ret = si_insert_input_ret(ctx, ret, ctx->args.tess_offchip_offset, 4);
   if (ctx->ac.gfx_level >= GFX11)
      ret = si_insert_input_ret(ctx, ret, ctx->args.gs_attr_offset, 5);

   ret = si_insert_input_ptr(ctx, ret, ctx->internal_bindings, 8 + SI_SGPR_INTERNAL_BINDINGS);
   ret = si_insert_input_ptr(ctx, ret, ctx->bindless_samplers_and_images,
                             8 + SI_SGPR_BINDLESS_SAMPLERS_AND_IMAGES);
   ret = si_insert_input_ptr(ctx, ret, ctx->const_and_shader_buffers,
                             8 + SI_SGPR_CONST_AND_SHADER_BUFFERS);
   ret = si_insert_input_ptr(ctx, ret, ctx->samplers_and_images, 8 + SI_SGPR_SAMPLERS_AND_IMAGES);
   ret = si_insert_input_ptr(ctx, ret, ctx->vs_state_bits, 8 + SI_SGPR_VS_STATE_BITS);
   if (ctx->ac.gfx_level >= GFX11)
      ret = si_insert_input_ptr(ctx, ret, ctx->gs_attr_address, 8 + GFX9_SGPR_ATTRIBUTE_RING_ADDR);

   if (ctx->stage == MESA_SHADER_VERTEX) {
      ret = si_insert_input_ptr(ctx, ret, ctx->args.base_vertex, 8 + SI_SGPR_BASE_VERTEX);
      ret = si_insert_input_ptr(ctx, ret, ctx->args.draw_id, 8 + SI_SGPR_DRAWID);
      ret = si_insert_input_ptr(ctx, ret, ctx->args.start_instance, 8 + SI_SGPR_START_INSTANCE);
      ret = si_insert_input_ptr(ctx, ret, ctx->args.vertex_buffers, 8 + GFX9_GS_NUM_USER_SGPR);

      for (unsigned i = 0; i < shader->selector->info.num_vbos_in_user_sgprs; i++) {
         ret = si_insert_input_v4i32(ctx, ret, ctx->vb_descriptors[i],
                                     8 + SI_SGPR_VS_VB_DESCRIPTOR_FIRST + i * 4);
      }
   } else {
      assert(ctx->stage == MESA_SHADER_TESS_EVAL);
      ret = si_insert_input_ptr(ctx, ret, ctx->tcs_offchip_layout, 8 + SI_SGPR_TES_OFFCHIP_LAYOUT);
      ret = si_insert_input_ptr(ctx, ret, ctx->tes_offchip_addr, 8 + SI_SGPR_TES_OFFCHIP_ADDR);
   }

   unsigned vgpr;
   if (ctx->stage == MESA_SHADER_VERTEX) {
      if (shader->selector->info.num_vbos_in_user_sgprs) {
         vgpr = 8 + SI_SGPR_VS_VB_DESCRIPTOR_FIRST + shader->selector->info.num_vbos_in_user_sgprs * 4;
      } else {
         vgpr = 8 + GFX9_GS_NUM_USER_SGPR + 1;
      }
   } else {
      vgpr = 8 + GFX9_GS_NUM_USER_SGPR;
   }

   val = LLVMBuildLoad2(builder, ctx->ac.i32, new_vgpr0, "");
   ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++, "");
   vgpr++; /* gs_vtx_offset[1] = offsets of vertices 2-3 */

   ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_prim_id, vgpr++);
   ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_invocation_id, vgpr++);
   vgpr++; /* gs_vtx_offset[2] = offsets of vertices 4-5 */

   /* Set the input VGPRs to the corresponding LDS addresses where the VGPR values are
    * stored. The VS prolog will load them.
    */
   if (ctx->stage == MESA_SHADER_VERTEX) {
      val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[0], ctx->ac.i32, "");
      ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++,
                                 ""); /* VGPR5 - VertexID */
      vgpr += 2;
      if (uses_instance_id) {
         val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[1], ctx->ac.i32, "");
         ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++,
                                    ""); /* VGPR8 - InstanceID */
      } else {
         vgpr++;
      }
   } else {
      assert(ctx->stage == MESA_SHADER_TESS_EVAL);
      unsigned num_vgprs = uses_tes_prim_id ? 4 : 3;
      for (unsigned i = 0; i < num_vgprs; i++) {
         val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[i], ctx->ac.i32, "");
         ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++, "");
      }
      if (num_vgprs == 3)
         vgpr++;
   }

   /* These two also use LDS. */
   if (gfx10_ngg_writes_user_edgeflags(shader) ||
       (ctx->stage == MESA_SHADER_VERTEX && shader->key.ge.mono.u.vs_export_prim_id)) {
      ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
      ac_build_s_barrier(&ctx->ac, ctx->stage);
   }

   ctx->return_value = ret;
}

/**
 * Emit the end of an API VS or TES shader compiled as an ESGS shader.
 */
void gfx10_ngg_build_end(struct si_shader_context *ctx)
{
   struct si_shader_selector *sel = ctx->shader->selector;
   struct si_shader_info *info = &sel->info;
   struct si_shader_output_values outputs[PIPE_MAX_SHADER_OUTPUTS];
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef *addrs = ctx->abi.outputs;
   LLVMValueRef tmp, tmp2;

   assert(!ctx->shader->is_gs_copy_shader);
   assert(info->num_outputs <= AC_LLVM_MAX_OUTPUTS);

   LLVMValueRef vertex_ptr = NULL;

   if (ctx->so.num_outputs || gfx10_ngg_writes_user_edgeflags(ctx->shader))
      vertex_ptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));

   for (unsigned i = 0; i < info->num_outputs; i++) {
      outputs[i].semantic = info->output_semantic[i];

      for (unsigned j = 0; j < 4; j++) {
         outputs[i].vertex_streams = info->output_streams[i];

         /* TODO: we may store more outputs than streamout needs,
          * but streamout performance isn't that important.
          */
         if (ctx->so.num_outputs) {
            tmp = ac_build_gep0(&ctx->ac, vertex_ptr, LLVMConstInt(ctx->ac.i32, 4 * i + j, false));
            tmp2 = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i + j], "");
            LLVMTypeRef type = ac_to_integer_type(&ctx->ac, ctx->ac.f32);
            tmp2 = LLVMBuildBitCast(ctx->ac.builder, tmp2, type, "");
            LLVMBuildStore(builder, tmp2, tmp);
         }
      }

      /* Store the edge flag at the end of the vertex (only when user edge flags are written). */
      if (info->output_semantic[i] == VARYING_SLOT_EDGE && gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
         LLVMValueRef edgeflag = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i], "");
         /* The output is a float, but the hw expects a 1-bit integer. */
         edgeflag = LLVMBuildFPToUI(ctx->ac.builder, edgeflag, ctx->ac.i32, "");
         edgeflag = ac_build_umin(&ctx->ac, edgeflag, ctx->ac.i32_1);

         tmp = LLVMConstInt(ctx->ac.i32, ngg_nogs_vertex_size(ctx->shader) - 1, 0);
         tmp = ac_build_gep0(&ctx->ac, vertex_ptr, tmp);
         LLVMBuildStore(builder, edgeflag, tmp);
      }
   }

   bool unterminated_es_if_block =
      !ctx->so.num_outputs && !gfx10_ngg_writes_user_edgeflags(ctx->shader) &&
      !ctx->screen->use_ngg_streamout && /* no query buffer */
      (ctx->stage != MESA_SHADER_VERTEX || !ctx->shader->key.ge.mono.u.vs_export_prim_id);

   if (!unterminated_es_if_block)
      ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);

   LLVMValueRef is_gs_thread = si_is_gs_thread(ctx);
   LLVMValueRef is_es_thread = si_is_es_thread(ctx);
   LLVMValueRef vtxindex[3];

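   /* With culling or NGG passthrough, all three 9-bit vertex indices are
    * packed into gs_vtx_offset[0] at 10-bit strides; otherwise each index is
    * one 16-bit half of the packed gs_vtx_offset[i / 2] registers.
    */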
   if (ctx->shader->key.ge.opt.ngg_culling || gfx10_is_ngg_passthrough(ctx->shader)) {
      for (unsigned i = 0; i < 3; ++i)
         vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[0], 10 * i, 9);
   } else {
      for (unsigned i = 0; i < 3; ++i)
         vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[i / 2], (i & 1) * 16, 16);
   }

   /* Determine the number of vertices per primitive. */
   unsigned num_vertices;
   LLVMValueRef num_vertices_val = ngg_get_vertices_per_prim(ctx, &num_vertices);

   /* Streamout */
   LLVMValueRef emitted_prims = NULL;

   if (ctx->so.num_outputs) {
      assert(!unterminated_es_if_block);

      struct ngg_streamout nggso = {};
      nggso.num_vertices = num_vertices_val;
      nggso.prim_enable[0] = is_gs_thread;

      for (unsigned i = 0; i < num_vertices; ++i)
         nggso.vertices[i] = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);

      build_streamout(ctx, &nggso);
      emitted_prims = nggso.emit[0];
   }

   LLVMValueRef user_edgeflags[3] = {};

   if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
      assert(!unterminated_es_if_block);

      /* Streamout already inserted the barrier, so don't insert it again. */
      if (!ctx->so.num_outputs) {
         ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
         ac_build_s_barrier(&ctx->ac, ctx->stage);
      }

      ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
      /* Load edge flags from ES threads and store them into VGPRs in GS threads. */
      for (unsigned i = 0; i < num_vertices; i++) {
         tmp = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);
         tmp2 = LLVMConstInt(ctx->ac.i32, ngg_nogs_vertex_size(ctx->shader) - 1, 0);
         tmp = ac_build_gep0(&ctx->ac, tmp, tmp2);
         tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
         tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");

         user_edgeflags[i] = ac_build_alloca_init(&ctx->ac, tmp, "");
      }
      ac_build_endif(&ctx->ac, 5400);
   }

   /* Copy Primitive IDs from GS threads to the LDS address corresponding
    * to the ES thread of the provoking vertex.
    */
   if (ctx->stage == MESA_SHADER_VERTEX && ctx->shader->key.ge.mono.u.vs_export_prim_id) {
      assert(!unterminated_es_if_block);

      /* Streamout and edge flags use LDS. Make it idle, so that we can reuse it. */
      if (ctx->so.num_outputs || gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
         ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
         ac_build_s_barrier(&ctx->ac, ctx->stage);
      }

      ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
      /* Extract the PROVOKING_VTX_INDEX field. */
      LLVMValueRef provoking_vtx_in_prim = GET_FIELD(ctx, GS_STATE_PROVOKING_VTX_INDEX);

      /* provoking_vtx_index = vtxindex[provoking_vtx_in_prim]; */
      LLVMValueRef indices = ac_build_gather_values(&ctx->ac, vtxindex, 3);
      LLVMValueRef provoking_vtx_index =
         LLVMBuildExtractElement(builder, indices, provoking_vtx_in_prim, "");
      LLVMValueRef vertex_ptr = ngg_nogs_vertex_ptr(ctx, provoking_vtx_index);

      LLVMBuildStore(builder, ac_get_arg(&ctx->ac, ctx->args.gs_prim_id),
                     ac_build_gep0(&ctx->ac, vertex_ptr, ctx->ac.i32_0));
      ac_build_endif(&ctx->ac, 5400);
   }

   /* Update query buffer */
   if (ctx->screen->use_ngg_streamout && !info->base.vs.blit_sgprs_amd) {
      assert(!unterminated_es_if_block);

      tmp = GET_FIELD(ctx, GS_STATE_STREAMOUT_QUERY_ENABLED);
      tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
      ac_build_ifcc(&ctx->ac, tmp, 5029); /* if (STREAMOUT_QUERY_ENABLED) */
      tmp = LLVMBuildICmp(builder, LLVMIntEQ, get_wave_id_in_tg(ctx), ctx->ac.i32_0, "");
      ac_build_ifcc(&ctx->ac, tmp, 5030);
      tmp = LLVMBuildICmp(builder, LLVMIntULE, ac_get_thread_id(&ctx->ac),
                          ctx->so.num_outputs ? ctx->ac.i32_1 : ctx->ac.i32_0, "");
      ac_build_ifcc(&ctx->ac, tmp, 5031);
      {
         LLVMValueRef args[] = {
            ngg_get_prim_cnt(ctx),
            ngg_get_query_buf(ctx),
            LLVMConstInt(ctx->ac.i32, 16, false), /* offset of stream[0].generated_primitives */
            ctx->ac.i32_0,                        /* soffset */
            ctx->ac.i32_0,                        /* cachepolicy */
         };

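         /* Lane 0 adds the generated count; with streamout, lane 1 adds the
          * emitted count: writelane patches the value and the offset for
          * lane 1 (24 presumably being the offset of
          * stream[0].emitted_primitives), so one atomic updates both.
          */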
         if (ctx->so.num_outputs) {
            args[0] = ac_build_writelane(&ctx->ac, args[0], emitted_prims, ctx->ac.i32_1);
            args[2] = ac_build_writelane(&ctx->ac, args[2], LLVMConstInt(ctx->ac.i32, 24, false),
                                         ctx->ac.i32_1);
         }

         /* TODO: should this be 64-bit atomics? */
         ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5,
                            0);
      }
      ac_build_endif(&ctx->ac, 5031);
      ac_build_endif(&ctx->ac, 5030);
      ac_build_endif(&ctx->ac, 5029);
   }

   /* Build the primitive export. */
   if (!gfx10_ngg_export_prim_early(ctx->shader)) {
      assert(!unterminated_es_if_block);
      gfx10_ngg_build_export_prim(ctx, user_edgeflags, NULL);
   }

   /* Export per-vertex data (positions and parameters). */
   if (!unterminated_es_if_block)
      ac_build_ifcc(&ctx->ac, is_es_thread, 6002);
   {
      unsigned i;

      /* Unconditionally (re-)load the values for proper SSA form. */
      for (i = 0; i < info->num_outputs; i++) {
         /* If the NGG cull shader part computed the position, don't
          * use the position from the current shader part. Instead,
          * load it from LDS.
          */
         if (info->output_semantic[i] == VARYING_SLOT_POS &&
             ctx->shader->key.ge.opt.ngg_culling) {
            vertex_ptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));

            for (unsigned j = 0; j < 4; j++) {
               tmp = LLVMConstInt(ctx->ac.i32, lds_pos_x + j, 0);
               tmp = ac_build_gep0(&ctx->ac, vertex_ptr, tmp);
               tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
               outputs[i].values[j] = LLVMBuildBitCast(ctx->ac.builder, tmp,
                                                       ac_to_float_type(&ctx->ac, ctx->ac.i32), "");
            }
         } else {
            for (unsigned j = 0; j < 4; j++) {
               outputs[i].values[j] = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i + j], "");
            }
         }
      }

      if (ctx->shader->key.ge.mono.u.vs_export_prim_id) {
         outputs[i].semantic = VARYING_SLOT_PRIMITIVE_ID;
         outputs[i].vertex_streams = 0;

         if (ctx->stage == MESA_SHADER_VERTEX) {
            /* Wait for LDS stores to finish. */
            ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
            ac_build_s_barrier(&ctx->ac, ctx->stage);

            tmp = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
            tmp = ac_build_gep0(&ctx->ac, tmp, ctx->ac.i32_0);
            outputs[i].values[0] = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
         } else {
            assert(ctx->stage == MESA_SHADER_TESS_EVAL);
            outputs[i].values[0] = si_get_primitive_id(ctx, 0);
         }

         outputs[i].values[0] = LLVMBuildBitCast(ctx->ac.builder, outputs[i].values[0], ctx->ac.f32, "");
         for (unsigned j = 1; j < 4; j++)
            outputs[i].values[j] = LLVMGetUndef(ctx->ac.f32);
         i++;
      }

      si_llvm_build_vs_exports(ctx, NULL, outputs, i);
   }
   ac_build_endif(&ctx->ac, 6002);
}

static LLVMValueRef ngg_gs_get_vertex_storage(struct si_shader_context *ctx)
{
   const struct si_shader_selector *sel = ctx->shader->selector;
   const struct si_shader_info *info = &sel->info;

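   /* Per-vertex LDS storage: 4 * num_outputs output dwords followed by one
    * dword of primitive flags (one byte per stream).
    */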
   LLVMTypeRef elements[2] = {
      LLVMArrayType(ctx->ac.i32, 4 * info->num_outputs),
      LLVMArrayType(ctx->ac.i8, 4),
   };
   LLVMTypeRef type = LLVMStructTypeInContext(ctx->ac.context, elements, 2, false);
   type = LLVMPointerType(LLVMArrayType(type, 0), AC_ADDR_SPACE_LDS);
   return LLVMBuildBitCast(ctx->ac.builder, ctx->gs_ngg_emit, type, "");
}

/**
 * Return a pointer to the LDS storage reserved for the N'th vertex, where N
 * is in emit order; that is:
 * - at the shader end, N is the threadidx (relative to the entire threadgroup)
 * - during vertex emit, i.e. while the API GS shader invocation is running,
 *   N = threadidx * gs.vertices_out + emitidx
 *
 * Goals of the LDS memory layout:
 * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
 *    in uniform control flow
 * 2. Eliminate bank conflicts on read for export if, additionally, there is no
 *    culling
 * 3. Agnostic to the number of waves (since we don't know it before compiling)
 * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
 * 5. Avoid wasting memory.
 *
 * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
 * layout, elimination of bank conflicts requires that each vertex occupy an
 * odd number of dwords. We use the additional dword to store the output stream
 * index as well as a flag to indicate whether this vertex ends a primitive
 * for rasterization.
 *
 * Swizzling is required to satisfy points 1 and 2 simultaneously.
 *
 * Vertices are stored in emit order (gsthread * gs.vertices_out + emitidx).
 * Indices are swizzled in groups of 32, which ensures point 1 without
 * disturbing point 2.
 *
 * \return an LDS pointer to type {[N x i32], [4 x i8]}
 */
static LLVMValueRef ngg_gs_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef vertexidx)
{
   struct si_shader_selector *sel = ctx->shader->selector;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef storage = ngg_gs_get_vertex_storage(ctx);

   /* gs.vertices_out = 2^(write_stride_2exp) * some odd number */
   unsigned write_stride_2exp = ffs(sel->info.base.gs.vertices_out) - 1;
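   /* Example (assuming gs.vertices_out = 4): write_stride_2exp = 2, so the
    * indices within each group of 32 vertices are XOR'ed with (group & 3),
    * which is what breaks up bank conflicts between groups.
    */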
   if (write_stride_2exp) {
      LLVMValueRef row = LLVMBuildLShr(builder, vertexidx, LLVMConstInt(ctx->ac.i32, 5, false), "");
      LLVMValueRef swizzle = LLVMBuildAnd(
         builder, row, LLVMConstInt(ctx->ac.i32, (1u << write_stride_2exp) - 1, false), "");
      vertexidx = LLVMBuildXor(builder, vertexidx, swizzle, "");
   }

   return ac_build_gep0(&ctx->ac, storage, vertexidx);
}

static LLVMValueRef ngg_gs_emit_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef gsthread,
                                           LLVMValueRef emitidx)
{
   struct si_shader_selector *sel = ctx->shader->selector;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef tmp;

   tmp = LLVMConstInt(ctx->ac.i32, sel->info.base.gs.vertices_out, false);
   tmp = LLVMBuildMul(builder, tmp, gsthread, "");
   const LLVMValueRef vertexidx = LLVMBuildAdd(builder, tmp, emitidx, "");
   return ngg_gs_vertex_ptr(ctx, vertexidx);
}

static LLVMValueRef ngg_gs_get_emit_output_ptr(struct si_shader_context *ctx,
                                               LLVMValueRef vertexptr, unsigned out_idx)
{
   LLVMValueRef gep_idx[3] = {
      ctx->ac.i32_0, /* implied C-style array */
      ctx->ac.i32_0, /* first struct entry */
      LLVMConstInt(ctx->ac.i32, out_idx, false),
   };
   return LLVMBuildGEP(ctx->ac.builder, vertexptr, gep_idx, 3, "");
}

static LLVMValueRef ngg_gs_get_emit_primflag_ptr(struct si_shader_context *ctx,
                                                 LLVMValueRef vertexptr, unsigned stream)
{
   LLVMValueRef gep_idx[3] = {
      ctx->ac.i32_0, /* implied C-style array */
      ctx->ac.i32_1, /* second struct entry */
      LLVMConstInt(ctx->ac.i32, stream, false),
   };
   return LLVMBuildGEP(ctx->ac.builder, vertexptr, gep_idx, 3, "");
}

void gfx10_ngg_gs_emit_vertex(struct si_shader_context *ctx, unsigned stream, LLVMValueRef *addrs)
{
   const struct si_shader_selector *sel = ctx->shader->selector;
   const struct si_shader_info *info = &sel->info;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef tmp;
   const LLVMValueRef vertexidx = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_next_vertex[stream], "");

   /* If this thread has already emitted the declared maximum number of
    * vertices, skip the write: excessive vertex emissions are not
    * supposed to have any effect.
    */
   const LLVMValueRef can_emit =
      LLVMBuildICmp(builder, LLVMIntULT, vertexidx,
                    LLVMConstInt(ctx->ac.i32, sel->info.base.gs.vertices_out, false), "");

   tmp = LLVMBuildAdd(builder, vertexidx, ctx->ac.i32_1, "");
   tmp = LLVMBuildSelect(builder, can_emit, tmp, vertexidx, "");
   LLVMBuildStore(builder, tmp, ctx->gs_next_vertex[stream]);

   ac_build_ifcc(&ctx->ac, can_emit, 9001);

   const LLVMValueRef vertexptr = ngg_gs_emit_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx), vertexidx);
   unsigned out_idx = 0;
   for (unsigned i = 0; i < info->num_outputs; i++) {
      for (unsigned chan = 0; chan < 4; chan++, out_idx++) {
         if (!(info->output_usagemask[i] & (1 << chan)) ||
             ((info->output_streams[i] >> (2 * chan)) & 3) != stream)
            continue;

         LLVMValueRef out_val = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i + chan], "");
         LLVMTypeRef as_int = ac_to_integer_type(&ctx->ac, ctx->ac.f32);
         out_val = LLVMBuildBitCast(ctx->ac.builder, out_val, as_int, "");
         LLVMBuildStore(builder, out_val, ngg_gs_get_emit_output_ptr(ctx, vertexptr, out_idx));
      }
   }
   assert(out_idx * 4 == info->gsvs_vertex_size);

   /* Determine and store whether this vertex completed a primitive. */
   const LLVMValueRef curverts = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_curprim_verts[stream], "");

   tmp = LLVMConstInt(ctx->ac.i32, u_vertices_per_prim(sel->info.base.gs.output_primitive) - 1, false);
   const LLVMValueRef iscompleteprim = LLVMBuildICmp(builder, LLVMIntUGE, curverts, tmp, "");

   /* When the geometry shader emits triangle strips, we need to
    * track which primitive is odd and swap vertex indices to get
    * the correct vertex order.
    */
   LLVMValueRef is_odd = ctx->ac.i1false;
   if (stream == 0 && u_vertices_per_prim(sel->info.base.gs.output_primitive) == 3) {
      tmp = LLVMBuildAnd(builder, curverts, ctx->ac.i32_1, "");
      is_odd = LLVMBuildICmp(builder, LLVMIntEQ, tmp, ctx->ac.i32_1, "");
   }

   tmp = LLVMBuildAdd(builder, curverts, ctx->ac.i32_1, "");
   LLVMBuildStore(builder, tmp, ctx->gs_curprim_verts[stream]);

   /* The per-vertex primitive flag encoding:
    *   bit 0: whether this vertex finishes a primitive
    *   bit 1: whether the primitive is odd (if we are emitting triangle strips)
    */
   tmp = LLVMBuildZExt(builder, iscompleteprim, ctx->ac.i8, "");
   tmp = LLVMBuildOr(
      builder, tmp,
      LLVMBuildShl(builder, LLVMBuildZExt(builder, is_odd, ctx->ac.i8, ""), ctx->ac.i8_1, ""), "");
   LLVMBuildStore(builder, tmp, ngg_gs_get_emit_primflag_ptr(ctx, vertexptr, stream));

   tmp = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_generated_prims[stream], "");
   tmp = LLVMBuildAdd(builder, tmp, LLVMBuildZExt(builder, iscompleteprim, ctx->ac.i32, ""), "");
   LLVMBuildStore(builder, tmp, ctx->gs_generated_prims[stream]);

   ac_build_endif(&ctx->ac, 9001);
}

void gfx10_ngg_gs_emit_begin(struct si_shader_context *ctx)
{
   /* Zero out the part of LDS scratch that is used to accumulate the
    * per-stream generated primitive count.
    */
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef scratchptr = ctx->gs_ngg_scratch;
   LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
   LLVMValueRef tmp;

   tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, LLVMConstInt(ctx->ac.i32, 4, false), "");
   ac_build_ifcc(&ctx->ac, tmp, 5090);
   {
      LLVMValueRef ptr = ac_build_gep0(&ctx->ac, scratchptr, tid);
      LLVMBuildStore(builder, ctx->ac.i32_0, ptr);
   }
   ac_build_endif(&ctx->ac, 5090);

   if (ctx->screen->info.gfx_level < GFX11) {
      tmp = si_is_gs_thread(ctx);
      ac_build_ifcc(&ctx->ac, tmp, 15090);
      {
         tmp = GET_FIELD(ctx, GS_STATE_PIPELINE_STATS_EMU);
         tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
         ac_build_ifcc(&ctx->ac, tmp, 5109); /* if (GS_PIPELINE_STATS_EMU) */
         LLVMValueRef args[] = {
            ctx->ac.i32_1,
            ngg_get_emulated_counters_buf(ctx),
            LLVMConstInt(ctx->ac.i32,
                         si_query_pipestat_end_dw_offset(ctx->screen, PIPE_STAT_QUERY_GS_INVOCATIONS) * 4,
                         false),
            ctx->ac.i32_0, /* soffset */
            ctx->ac.i32_0, /* cachepolicy */
         };

         ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5, 0);
         ac_build_endif(&ctx->ac, 5109);
      }
      ac_build_endif(&ctx->ac, 15090);
   }

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);
}

void gfx10_ngg_gs_build_end(struct si_shader_context *ctx)
{
   const struct si_shader_selector *sel = ctx->shader->selector;
   const struct si_shader_info *info = &sel->info;
   const unsigned verts_per_prim = u_vertices_per_prim(sel->info.base.gs.output_primitive);
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef i8_0 = LLVMConstInt(ctx->ac.i8, 0, false);
   LLVMValueRef tmp, tmp2;

   /* Zero out remaining (non-emitted) primitive flags.
    *
    * Note: Alternatively, we could pass the relevant gs_next_vertex to
    *       the emit threads via LDS. This is likely worse in the expected
    *       typical case where each GS thread emits the full set of
    *       vertices.
    */
   for (unsigned stream = 0; stream < 4; ++stream) {
      if (!info->num_stream_output_components[stream])
         continue;

      const LLVMValueRef gsthread = gfx10_get_thread_id_in_tg(ctx);

      ac_build_bgnloop(&ctx->ac, 5100);

      const LLVMValueRef vertexidx = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_next_vertex[stream], "");
      tmp = LLVMBuildICmp(builder, LLVMIntUGE, vertexidx,
                          LLVMConstInt(ctx->ac.i32, sel->info.base.gs.vertices_out, false), "");
      ac_build_ifcc(&ctx->ac, tmp, 5101);
      ac_build_break(&ctx->ac);
      ac_build_endif(&ctx->ac, 5101);

      tmp = LLVMBuildAdd(builder, vertexidx, ctx->ac.i32_1, "");
      LLVMBuildStore(builder, tmp, ctx->gs_next_vertex[stream]);

      tmp = ngg_gs_emit_vertex_ptr(ctx, gsthread, vertexidx);
      LLVMBuildStore(builder, i8_0, ngg_gs_get_emit_primflag_ptr(ctx, tmp, stream));

      ac_build_endloop(&ctx->ac, 5100);
   }

   /* Accumulate generated primitives counts across the entire threadgroup. */
   for (unsigned stream = 0; stream < 4; ++stream) {
      if (!info->num_stream_output_components[stream])
         continue;

      LLVMValueRef numprims = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_generated_prims[stream], "");
      numprims = ac_build_reduce(&ctx->ac, numprims, nir_op_iadd, ctx->ac.wave_size);

      tmp = LLVMBuildICmp(builder, LLVMIntEQ, ac_get_thread_id(&ctx->ac), ctx->ac.i32_0, "");
      ac_build_ifcc(&ctx->ac, tmp, 5105);
      {
         LLVMBuildAtomicRMW(
            builder, LLVMAtomicRMWBinOpAdd,
            ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, LLVMConstInt(ctx->ac.i32, stream, false)),
            numprims, LLVMAtomicOrderingMonotonic, false);
      }
      ac_build_endif(&ctx->ac, 5105);
   }

   ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   const LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
   LLVMValueRef num_emit_threads = ngg_get_prim_cnt(ctx);

   /* Streamout */
   if (ctx->so.num_outputs) {
      struct ngg_streamout nggso = {};

      nggso.num_vertices = LLVMConstInt(ctx->ac.i32, verts_per_prim, false);

      LLVMValueRef vertexptr = ngg_gs_vertex_ptr(ctx, tid);
      for (unsigned stream = 0; stream < 4; ++stream) {
         if (!info->num_stream_output_components[stream])
            continue;

         tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, vertexptr, stream), "");
         tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
         tmp2 = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
         nggso.prim_enable[stream] = LLVMBuildAnd(builder, tmp, tmp2, "");
      }

      for (unsigned i = 0; i < verts_per_prim; ++i) {
         tmp = LLVMBuildSub(builder, tid, LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false),
                            "");
         tmp = ngg_gs_vertex_ptr(ctx, tmp);
         nggso.vertices[i] = ac_build_gep0(&ctx->ac, tmp, ctx->ac.i32_0);
      }

      build_streamout(ctx, &nggso);
   }

   /* Write shader query data. */
   if (ctx->screen->use_ngg_streamout) {
      tmp = GET_FIELD(ctx, GS_STATE_STREAMOUT_QUERY_ENABLED);
      tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
      ac_build_ifcc(&ctx->ac, tmp, 5109); /* if (STREAMOUT_QUERY_ENABLED) */
      unsigned num_query_comps = ctx->so.num_outputs ? 8 : 4;
      tmp = LLVMBuildICmp(builder, LLVMIntULT, tid,
                          LLVMConstInt(ctx->ac.i32, num_query_comps, false), "");
      ac_build_ifcc(&ctx->ac, tmp, 5110);
      {
         LLVMValueRef offset;
         tmp = tid;
         if (ctx->so.num_outputs)
            tmp = LLVMBuildAnd(builder, tmp, LLVMConstInt(ctx->ac.i32, 3, false), "");
         offset = LLVMBuildNUWMul(builder, tmp, LLVMConstInt(ctx->ac.i32, 32, false), "");
         if (ctx->so.num_outputs) {
            tmp = LLVMBuildLShr(builder, tid, LLVMConstInt(ctx->ac.i32, 2, false), "");
            tmp = LLVMBuildNUWMul(builder, tmp, LLVMConstInt(ctx->ac.i32, 8, false), "");
            offset = LLVMBuildAdd(builder, offset, tmp, "");
         }

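         /* Assumed query buffer layout: 32 bytes per stream starting at
          * byte 16 (the soffset below), with the generated count at +0 and
          * the emitted count at +8 within each stream's record.
          */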
2083
2084         tmp = LLVMBuildLoad2(builder, ctx->ac.i32, ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tid), "");
2085         LLVMValueRef args[] = {
2086            tmp,           ngg_get_query_buf(ctx),
2087            offset,        LLVMConstInt(ctx->ac.i32, 16, false), /* soffset */
2088            ctx->ac.i32_0,                                       /* cachepolicy */
2089         };
2090         ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5,
2091                            0);
2092      }
2093      ac_build_endif(&ctx->ac, 5110);
2094      ac_build_endif(&ctx->ac, 5109);
2095   }
2096
2097   /* Cull primitives. */
2098   if (ctx->shader->key.ge.opt.ngg_culling) {
2099      assert(info->num_stream_output_components[0]);
2100
2101      LLVMValueRef gs_vtxptr = ngg_gs_vertex_ptr(ctx, tid);
2102      LLVMValueRef live = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, gs_vtxptr, 0), "");
2103      live = LLVMBuildTrunc(builder, live, ctx->ac.i1, "");
2104      LLVMValueRef is_emit = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
2105      LLVMValueRef prim_enable = LLVMBuildAnd(builder, live, is_emit, "");
2106
2107      /* Wait for streamout to finish before we kill primitives. */
2108      if (ctx->so.num_outputs) {
2109         ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
2110         ac_build_s_barrier(&ctx->ac, ctx->stage);
2111      }
2112
2113      ac_build_ifcc(&ctx->ac, prim_enable, 0);
2114      {
2115         LLVMValueRef vtxptr[3] = {};
2116         LLVMValueRef pos[3][4] = {};
2117
2118         for (unsigned i = 0; i < verts_per_prim; i++) {
2119            tmp = LLVMBuildSub(builder, tid, LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false), "");
2120            vtxptr[i] = ac_build_gep0(&ctx->ac, ngg_gs_vertex_ptr(ctx, tmp), ctx->ac.i32_0);
2121         }
2122
2123         for (unsigned i = 0; i < info->num_outputs; i++) {
2124            /* If the stream index is non-zero for all channels, skip the output. */
2125            if (info->output_streams[i] & 0x3 &&
2126                (info->output_streams[i] >> 2) & 0x3 &&
2127                (info->output_streams[i] >> 4) & 0x3 &&
2128                (info->output_streams[i] >> 6) & 0x3)
2129               continue;
2130
2131            switch (info->output_semantic[i]) {
2132            case VARYING_SLOT_POS:
2133               /* Load the positions from LDS. */
2134               for (unsigned vert = 0; vert < verts_per_prim; vert++) {
2135                  for (unsigned comp = 0; comp < 4; comp++) {
2136                     /* Z is not needed. */
2137                     if (comp == 2)
2138                        continue;
2139
2140                     tmp = ac_build_gep0(&ctx->ac, vtxptr[vert],
2141                                         LLVMConstInt(ctx->ac.i32, 4 * i + comp, false));
                     pos[vert][comp] = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
                     pos[vert][comp] = ac_to_float(&ctx->ac, pos[vert][comp]);
                  }
               }

               /* Divide XY by W. */
               for (unsigned vert = 0; vert < verts_per_prim; vert++) {
                  for (unsigned comp = 0; comp < 2; comp++)
                     pos[vert][comp] = ac_build_fdiv(&ctx->ac, pos[vert][comp], pos[vert][3]);
               }
               break;
            }
         }

         LLVMValueRef clipdist_accepted = ctx->ac.i1true; /* TODO */
         LLVMValueRef accepted = ac_build_alloca(&ctx->ac, ctx->ac.i32, "");

         cull_primitive(ctx, pos, clipdist_accepted, accepted, NULL);

         accepted = LLVMBuildLoad2(builder, ctx->ac.i32, accepted, "");
         LLVMValueRef rejected = LLVMBuildNot(builder, LLVMBuildTrunc(builder, accepted, ctx->ac.i1, ""), "");

         ac_build_ifcc(&ctx->ac, rejected, 0);
         LLVMBuildStore(builder, ctx->ac.i8_0, ngg_gs_get_emit_primflag_ptr(ctx, gs_vtxptr, 0));
         ac_build_endif(&ctx->ac, 0);
      }
      ac_build_endif(&ctx->ac, 0);

      ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
      ac_build_s_barrier(&ctx->ac, ctx->stage);
   }

   /* Determine vertex liveness. */
   LLVMValueRef vertliveptr = ac_build_alloca(&ctx->ac, ctx->ac.i1, "vertexlive");

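   /* A vertex is live if any primitive that uses it is live. The primitive at
    * thread p uses vertices p - (verts_per_prim - 1) .. p, so thread tid
    * checks the flags of primitives tid .. tid + verts_per_prim - 1.
    */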
   tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
   ac_build_ifcc(&ctx->ac, tmp, 5120);
   {
      for (unsigned i = 0; i < verts_per_prim; ++i) {
         const LLVMValueRef primidx =
            LLVMBuildAdd(builder, tid, LLVMConstInt(ctx->ac.i32, i, false), "");

         if (i > 0) {
            tmp = LLVMBuildICmp(builder, LLVMIntULT, primidx, num_emit_threads, "");
            ac_build_ifcc(&ctx->ac, tmp, 5121 + i);
         }

         /* Load primitive liveness */
         tmp = ngg_gs_vertex_ptr(ctx, primidx);
         tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 0), "");
         const LLVMValueRef primlive = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");

         tmp = LLVMBuildLoad2(builder, ctx->ac.i1, vertliveptr, "");
         tmp = LLVMBuildOr(builder, tmp, primlive, "");
         LLVMBuildStore(builder, tmp, vertliveptr);

         if (i > 0)
            ac_build_endif(&ctx->ac, 5121 + i);
      }
   }
   ac_build_endif(&ctx->ac, 5120);

   /* Compute the prefix sum (exclusive) and the total of vertex liveness across the threadgroup. */
   LLVMValueRef vertlive = LLVMBuildLoad2(builder, ctx->ac.i1, vertliveptr, "");
   struct ac_wg_scan vertlive_scan = {};
   vertlive_scan.stage = ctx->stage;
   vertlive_scan.op = nir_op_iadd;
   vertlive_scan.enable_reduce = true;
   vertlive_scan.enable_exclusive = true;
   vertlive_scan.src = vertlive;
   vertlive_scan.scratch = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, ctx->ac.i32_0);
   vertlive_scan.waveidx = get_wave_id_in_tg(ctx);
   vertlive_scan.numwaves = get_tgsize(ctx);
   vertlive_scan.maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);

   ac_build_wg_scan(&ctx->ac, &vertlive_scan);

   /* Skip all exports (including index exports) when possible. */
   LLVMValueRef have_exports =
      LLVMBuildICmp(builder, LLVMIntNE, vertlive_scan.result_reduce, ctx->ac.i32_0, "");
   num_emit_threads = LLVMBuildSelect(builder, have_exports, num_emit_threads, ctx->ac.i32_0, "");

   /* Allocate export space. Send this message as early as possible, to
    * hide the latency of the SQ <-> SPI roundtrip.
    */
   ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), vertlive_scan.result_reduce,
                                 num_emit_threads);

   /* Setup the reverse vertex compaction permutation. We re-use stream 1
    * of the primitive liveness flags, relying on the fact that each
    * threadgroup can have at most 256 threads. */
   ac_build_ifcc(&ctx->ac, vertlive, 5130);
   {
      tmp = ngg_gs_vertex_ptr(ctx, vertlive_scan.result_exclusive);
      tmp2 = LLVMBuildTrunc(builder, tid, ctx->ac.i8, "");
      LLVMBuildStore(builder, tmp2, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 1));
   }
   ac_build_endif(&ctx->ac, 5130);

   ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
   ac_build_s_barrier(&ctx->ac, ctx->stage);

   /* Export primitive data */
   tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
   ac_build_ifcc(&ctx->ac, tmp, 5140);
   {
      LLVMValueRef flags;
      struct ac_ngg_prim prim = {};
      prim.num_vertices = verts_per_prim;

      tmp = ngg_gs_vertex_ptr(ctx, tid);
      flags = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 0), "");
      prim.isnull = LLVMBuildNot(builder, LLVMBuildTrunc(builder, flags, ctx->ac.i1, ""), "");
      prim.edgeflags = ctx->ac.i32_0;

      for (unsigned i = 0; i < verts_per_prim; ++i) {
         prim.index[i] = LLVMBuildSub(builder, vertlive_scan.result_exclusive,
                                      LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false), "");
      }

      /* Geometry shaders output triangle strips, but NGG expects triangles. */
      if (verts_per_prim == 3) {
         LLVMValueRef is_odd = LLVMBuildLShr(builder, flags, ctx->ac.i8_1, "");
         is_odd = LLVMBuildTrunc(builder, is_odd, ctx->ac.i1, "");
         LLVMValueRef flatshade_first = LLVMBuildICmp(
            builder, LLVMIntEQ, GET_FIELD(ctx, GS_STATE_PROVOKING_VTX_INDEX), ctx->ac.i32_0, "");

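         /* For odd triangles of a strip, the helper swaps two indices to
          * restore the winding order; with flatshade_first it is assumed to
          * keep the provoking vertex in place.
          */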
         ac_build_triangle_strip_indices_to_triangle(&ctx->ac, is_odd, flatshade_first, prim.index);
      }

      ac_build_export_prim(&ctx->ac, &prim);

      if (ctx->screen->info.gfx_level < GFX11) {
         tmp = GET_FIELD(ctx, GS_STATE_PIPELINE_STATS_EMU);
         tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
         ac_build_ifcc(&ctx->ac, tmp, 5229); /* if (GS_PIPELINE_STATS_EMU) */
         ac_build_ifcc(&ctx->ac, LLVMBuildNot(builder, prim.isnull, ""), 5237);
         {
            LLVMValueRef args[] = {
               ctx->ac.i32_1,
               ngg_get_emulated_counters_buf(ctx),
               LLVMConstInt(ctx->ac.i32,
                            si_query_pipestat_end_dw_offset(ctx->screen, PIPE_STAT_QUERY_GS_PRIMITIVES) * 4,
                            false),
               ctx->ac.i32_0,                            /* soffset */
               ctx->ac.i32_0,                            /* cachepolicy */
            };

            ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5, 0);
         }
         ac_build_endif(&ctx->ac, 5237);
         ac_build_endif(&ctx->ac, 5229);
      }
   }
   ac_build_endif(&ctx->ac, 5140);

   /* Export position and parameter data */
   LLVMValueRef num_export_threads = vertlive_scan.result_reduce;
   tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_export_threads, "");
   ac_build_ifcc(&ctx->ac, tmp, 5145);
   {
      struct si_shader_output_values outputs[PIPE_MAX_SHADER_OUTPUTS];

      tmp = ngg_gs_vertex_ptr(ctx, tid);
      tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 1), "");
      tmp = LLVMBuildZExt(builder, tmp, ctx->ac.i32, "");
      const LLVMValueRef vertexptr = ngg_gs_vertex_ptr(ctx, tmp);

      unsigned out_idx = 0;
      for (unsigned i = 0; i < info->num_outputs; i++) {
         outputs[i].semantic = info->output_semantic[i];

         for (unsigned j = 0; j < 4; j++, out_idx++) {
            tmp = ngg_gs_get_emit_output_ptr(ctx, vertexptr, out_idx);
            tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
            assert(LLVMGetTypeKind(LLVMTypeOf(tmp)) != LLVMPointerTypeKind);
            outputs[i].values[j] = ac_to_float(&ctx->ac, tmp);
            outputs[i].vertex_streams = info->output_streams[i];
         }
      }

      si_llvm_build_vs_exports(ctx, num_export_threads, outputs, info->num_outputs);
   }
   ac_build_endif(&ctx->ac, 5145);
}

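/* The first primitive consumes min_verts_per_prim vertices and each
 * additional primitive needs at least one new vertex (two with adjacency),
 * so max_esverts vertices can yield at most 1 + max_reuse primitives.
 */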
static void clamp_gsprims_to_esverts(unsigned *max_gsprims, unsigned max_esverts,
                                     unsigned min_verts_per_prim, bool use_adjacency)
{
   unsigned max_reuse = max_esverts - min_verts_per_prim;
   if (use_adjacency)
      max_reuse /= 2;
   *max_gsprims = MIN2(*max_gsprims, 1 + max_reuse);
}

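/* LDS scratch size in dwords (assumed use: per-stream primitive counts,
 * per-wave vertex counts and workgroup-scan scratch; GS with streamout
 * needs the larger area for streamout bookkeeping).
 */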
unsigned gfx10_ngg_get_scratch_dw_size(struct si_shader *shader)
{
   const struct si_shader_selector *sel = shader->selector;

   if (sel->stage == MESA_SHADER_GEOMETRY && si_shader_uses_streamout(shader))
      return 44;

   return 8;
}

/**
 * Determine subgroup information like maximum number of vertices and prims.
 *
 * This happens before the shader is uploaded, since LDS relocations during
 * upload depend on the subgroup size.
 */
bool gfx10_ngg_calculate_subgroup_info(struct si_shader *shader)
{
   const struct si_shader_selector *gs_sel = shader->selector;
   const struct si_shader_selector *es_sel =
      shader->previous_stage_sel ? shader->previous_stage_sel : gs_sel;
   const gl_shader_stage gs_stage = gs_sel->stage;
   const unsigned gs_num_invocations = MAX2(gs_sel->info.base.gs.invocations, 1);
   const unsigned input_prim = si_get_input_prim(gs_sel, &shader->key);
   const bool use_adjacency =
      input_prim >= PIPE_PRIM_LINES_ADJACENCY && input_prim <= PIPE_PRIM_TRIANGLE_STRIP_ADJACENCY;
   const unsigned max_verts_per_prim = u_vertices_per_prim(input_prim);
   const unsigned min_verts_per_prim = gs_stage == MESA_SHADER_GEOMETRY ? max_verts_per_prim : 1;

   /* All these are in dwords.
    * GE can only use 8K dwords (32KB) of LDS per workgroup.
    */
   const unsigned max_lds_size = 8 * 1024 - gfx10_ngg_get_scratch_dw_size(shader);
   const unsigned target_lds_size = max_lds_size;
   unsigned esvert_lds_size = 0;
   unsigned gsprim_lds_size = 0;

   /* All these are per subgroup: */
   const unsigned min_esverts =
      gs_sel->screen->info.gfx_level >= GFX11 ? 3 : /* gfx11 requires at least 1 primitive per TG */
      gs_sel->screen->info.gfx_level >= GFX10_3 ? 29 : (24 - 1 + max_verts_per_prim);
   bool max_vert_out_per_gs_instance = false;
   unsigned max_gsprims_base = gs_sel->screen->ngg_subgroup_size; /* default prim group size clamp */
   unsigned max_esverts_base = gs_sel->screen->ngg_subgroup_size;

   if (gs_stage == MESA_SHADER_GEOMETRY) {
      bool force_multi_cycling = false;
      unsigned max_out_verts_per_gsprim = gs_sel->info.base.gs.vertices_out * gs_num_invocations;

retry_select_mode:
      if (max_out_verts_per_gsprim <= 256 && !force_multi_cycling) {
         if (max_out_verts_per_gsprim) {
            max_gsprims_base = MIN2(max_gsprims_base, 256 / max_out_verts_per_gsprim);
         }
      } else {
         /* Use special multi-cycling mode in which each GS
          * instance gets its own subgroup. Does not work with
          * tessellation. */
         max_vert_out_per_gs_instance = true;
         max_gsprims_base = 1;
         max_out_verts_per_gsprim = gs_sel->info.base.gs.vertices_out;
      }

      esvert_lds_size = es_sel->info.esgs_itemsize / 4;
      gsprim_lds_size = (gs_sel->info.gsvs_vertex_size / 4 + 1) * max_out_verts_per_gsprim;

      if (gsprim_lds_size > target_lds_size && !force_multi_cycling) {
         if (gs_sel->tess_turns_off_ngg || es_sel->stage != MESA_SHADER_TESS_EVAL) {
            force_multi_cycling = true;
            goto retry_select_mode;
         }
      }
   } else {
      /* VS and TES. */
      /* LDS size for passing data from ES to GS. */
      esvert_lds_size = ngg_nogs_vertex_size(shader);
   }

   unsigned max_gsprims = max_gsprims_base;
   unsigned max_esverts = max_esverts_base;

   if (esvert_lds_size)
      max_esverts = MIN2(max_esverts, target_lds_size / esvert_lds_size);
   if (gsprim_lds_size)
      max_gsprims = MIN2(max_gsprims, target_lds_size / gsprim_lds_size);

   max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
   clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
   assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);

   if (esvert_lds_size || gsprim_lds_size) {
      /* Now that we have a rough proportionality between esverts
       * and gsprims based on the primitive type, scale both of them
       * down simultaneously based on required LDS space.
       *
       * We could be smarter about this if we knew how much vertex
       * reuse to expect.
       */
      unsigned lds_total = max_esverts * esvert_lds_size + max_gsprims * gsprim_lds_size;
      if (lds_total > target_lds_size) {
         max_esverts = max_esverts * target_lds_size / lds_total;
         max_gsprims = max_gsprims * target_lds_size / lds_total;

         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      }
   }

   /* Round up towards full wave sizes for better ALU utilization. */
   if (!max_vert_out_per_gs_instance) {
      unsigned orig_max_esverts;
      unsigned orig_max_gsprims;
      do {
         orig_max_esverts = max_esverts;
         orig_max_gsprims = max_gsprims;

         max_esverts = align(max_esverts, shader->wave_size);
         max_esverts = MIN2(max_esverts, max_esverts_base);
         if (esvert_lds_size)
            max_esverts =
               MIN2(max_esverts, (max_lds_size - max_gsprims * gsprim_lds_size) / esvert_lds_size);
         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);

         /* Hardware restriction: minimum value of max_esverts */
         max_esverts = MAX2(max_esverts, min_esverts);

         max_gsprims = align(max_gsprims, shader->wave_size);
         max_gsprims = MIN2(max_gsprims, max_gsprims_base);
         if (gsprim_lds_size) {
            /* Don't count unusable vertices to the LDS size. Those are vertices above
             * the maximum number of vertices that can occur in the workgroup,
             * which is e.g. max_gsprims * 3 for triangles.
             */
            unsigned usable_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
            max_gsprims =
               MIN2(max_gsprims, (max_lds_size - usable_esverts * esvert_lds_size) / gsprim_lds_size);
         }
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      } while (orig_max_esverts != max_esverts || orig_max_gsprims != max_gsprims);

      /* Verify the restriction. */
      assert(max_esverts >= min_esverts);
   } else {
      max_esverts = MAX2(max_esverts, min_esverts);
   }

   unsigned max_out_vertices =
      max_vert_out_per_gs_instance
         ? gs_sel->info.base.gs.vertices_out
         : gs_stage == MESA_SHADER_GEOMETRY
              ? max_gsprims * gs_num_invocations * gs_sel->info.base.gs.vertices_out
              : max_esverts;
   assert(max_out_vertices <= 256);

   unsigned prim_amp_factor = 1;
   if (gs_stage == MESA_SHADER_GEOMETRY) {
      /* Number of output primitives per GS input primitive after
       * GS instancing. */
      prim_amp_factor = gs_sel->info.base.gs.vertices_out;
   }

   shader->ngg.hw_max_esverts = max_esverts;
   shader->ngg.max_gsprims = max_gsprims;
   shader->ngg.max_out_verts = max_out_vertices;
   shader->ngg.prim_amp_factor = prim_amp_factor;
   shader->ngg.max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;

   /* Don't count unusable vertices. */
   shader->gs_info.esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) *
                                    esvert_lds_size;
   shader->ngg.ngg_emit_size = max_gsprims * gsprim_lds_size;

   assert(shader->ngg.hw_max_esverts >= min_esverts); /* HW limitation */

   /* If asserts are disabled, we use the same conditions to return false */
   return max_esverts >= max_verts_per_prim && max_gsprims >= 1 &&
          max_out_vertices <= 256 &&
          shader->ngg.hw_max_esverts >= min_esverts;
}