/*
 * Copyright © 2022 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Timur Kristóf
 *
 */

#include "nir.h"
#include "nir_builder.h"
#include "util/u_math.h"

typedef struct {
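   /* Shared memory address where the TASK_COUNT output is stored. */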
   uint32_t task_count_shared_addr;
} lower_task_nv_state;

typedef struct {
   /* If true, lower all task_payload I/O to use shared memory. */
   bool payload_in_shared;
   /* Shared memory address where task_payload will be located. */
   uint32_t payload_shared_addr;
} lower_task_state;

static bool
lower_nv_task_output(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_nv_state *s = (lower_task_nv_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_output: {
      b->cursor = nir_after_instr(instr);
      nir_ssa_def *load =
         nir_load_shared(b, 1, 32, nir_imm_int(b, s->task_count_shared_addr));
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, load);
      nir_instr_remove(instr);
      return true;
   }

   case nir_intrinsic_store_output: {
      b->cursor = nir_after_instr(instr);
      nir_ssa_def *store_val = intrin->src[0].ssa;
      nir_store_shared(b, store_val, nir_imm_int(b, s->task_count_shared_addr));
      nir_instr_remove(instr);
      return true;
   }

   default:
      return false;
   }
}

static void
append_launch_mesh_workgroups_to_nv_task(nir_builder *b,
                                         lower_task_nv_state *s)
{
   /* At the beginning of the shader, write 0 to the task count.
    * This ensures that 0 mesh workgroups are launched when the
    * shader doesn't write the TASK_COUNT output.
    */
   b->cursor = nir_before_cf_list(&b->impl->body);
   nir_ssa_def *zero = nir_imm_int(b, 0);
   nir_store_shared(b, zero, nir_imm_int(b, s->task_count_shared_addr));

   nir_scoped_barrier(b,
         .execution_scope = NIR_SCOPE_WORKGROUP,
         .memory_scope = NIR_SCOPE_WORKGROUP,
         .memory_semantics = NIR_MEMORY_RELEASE,
         .memory_modes = nir_var_mem_shared);

   /* At the end of the shader, read the task count from shared memory
    * and emit launch_mesh_workgroups.
    */
   b->cursor = nir_after_cf_list(&b->impl->body);

   nir_scoped_barrier(b,
         .execution_scope = NIR_SCOPE_WORKGROUP,
         .memory_scope = NIR_SCOPE_WORKGROUP,
         .memory_semantics = NIR_MEMORY_ACQUIRE,
         .memory_modes = nir_var_mem_shared);

   nir_ssa_def *task_count =
      nir_load_shared(b, 1, 32, nir_imm_int(b, s->task_count_shared_addr));

   /* NV_mesh_shader doesn't offer a way to choose which task_payload
    * variables should be passed to the mesh shader, so we just pass all of them.
    */
   uint32_t range = b->shader->info.task_payload_size;

   nir_ssa_def *one = nir_imm_int(b, 1);
   nir_ssa_def *dispatch_3d = nir_vec3(b, task_count, one, one);
   nir_launch_mesh_workgroups(b, dispatch_3d, .base = 0, .range = range);
}

/**
 * For NV_mesh_shader:
 * Task shaders have only one output, TASK_COUNT, a 32-bit unsigned int
 * that contains the 1-dimensional mesh dispatch size.
 * This output is expected to behave like a shared variable.
 *
 * We lower this output to a shared variable and then emit
 * the new launch_mesh_workgroups intrinsic at the end of the shader.
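 *
 * The result is roughly the following (illustrative sketch, not literal NIR):
 *
 *    shared[task_count_addr] = 0;    // added at the top of the shader
 *    barrier();                      // release
 *    ...                             // TASK_COUNT stores become shared stores
 *    barrier();                      // acquire
 *    launch_mesh_workgroups(shared[task_count_addr], 1, 1);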
 */
static void
nir_lower_nv_task_count(nir_shader *shader)
{
   lower_task_nv_state state = {
      .task_count_shared_addr = ALIGN(shader->info.shared_size, 4),
   };

   shader->info.shared_size += 4;
   nir_shader_instructions_pass(shader, lower_nv_task_output,
                                nir_metadata_none, &state);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder;
   nir_builder_init(&builder, impl);

   append_launch_mesh_workgroups_to_nv_task(&builder, &state);
   nir_metadata_preserve(impl, nir_metadata_none);
}

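/* Map a task_payload load/store/atomic intrinsic to the equivalent
 * shared memory intrinsic.
 */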
static nir_intrinsic_op
shared_opcode_for_task_payload(nir_intrinsic_op task_payload_op)
{
   switch (task_payload_op) {
#define OP(O) case nir_intrinsic_task_payload_##O: return nir_intrinsic_shared_##O;
   OP(atomic_exchange)
   OP(atomic_comp_swap)
   OP(atomic_add)
   OP(atomic_imin)
   OP(atomic_umin)
   OP(atomic_imax)
   OP(atomic_umax)
   OP(atomic_and)
   OP(atomic_or)
   OP(atomic_xor)
   OP(atomic_fadd)
   OP(atomic_fmin)
   OP(atomic_fmax)
   OP(atomic_fcomp_swap)
#undef OP
   case nir_intrinsic_load_task_payload:
      return nir_intrinsic_load_shared;
   case nir_intrinsic_store_task_payload:
      return nir_intrinsic_store_shared;
   default:
      unreachable("Invalid task payload atomic");
   }
}

static bool
lower_task_payload_to_shared(nir_builder *b,
                             nir_intrinsic_instr *intrin,
                             lower_task_state *s)
{
   /* This assumes that shared and task_payload intrinsics
    * have the same number of sources and same indices.
    */
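   /* For example (illustrative): a task_payload_atomic_add with base=B
    * becomes a shared_atomic_add with base = B + payload_shared_addr;
    * the offset and data sources are left untouched.
    */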
   unsigned base = nir_intrinsic_base(intrin);
   intrin->intrinsic = shared_opcode_for_task_payload(intrin->intrinsic);
   nir_intrinsic_set_base(intrin, base + s->payload_shared_addr);

   return true;
}

static void
emit_shared_to_payload_copy(nir_builder *b,
                            uint32_t payload_addr,
                            uint32_t payload_size,
                            lower_task_state *s)
{
   const unsigned invocations = b->shader->info.workgroup_size[0] *
                                b->shader->info.workgroup_size[1] *
                                b->shader->info.workgroup_size[2];
   const unsigned bytes_per_copy = 16;
   const unsigned copies_needed = DIV_ROUND_UP(payload_size, bytes_per_copy);
   const unsigned copies_per_invocation = DIV_ROUND_UP(copies_needed, invocations);
   const unsigned base_shared_addr = s->payload_shared_addr + payload_addr;

   nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
   nir_ssa_def *addr = nir_imul_imm(b, invocation_index, bytes_per_copy);

   /* Wait for all previous shared stores to finish.
    * This is necessary because we placed the payload in shared memory.
    */
   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
                         .memory_scope = NIR_SCOPE_WORKGROUP,
                         .memory_semantics = NIR_MEMORY_ACQ_REL,
                         .memory_modes = nir_var_mem_shared);

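   /* Worked example (numbers are illustrative): with a 32-invocation
    * workgroup and a 16384-byte payload, copies_needed = 1024 and
    * copies_per_invocation = 32; in iteration i, invocation j copies the
    * 16 bytes at offset i * invocations * 16 + j * 16 from the shared
    * copy of the payload to task payload memory.
    */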
   for (unsigned i = 0; i < copies_per_invocation; ++i) {
      unsigned const_off = bytes_per_copy * invocations * i;

      /* Read from shared memory. */
      nir_ssa_def *copy =
         nir_load_shared(b, 4, 32, addr, .align_mul = 16,
                         .base = base_shared_addr + const_off);

      /* Write to task payload memory. */
      nir_store_task_payload(b, copy, addr, .base = const_off);
   }
}

static bool
lower_task_launch_mesh_workgroups(nir_builder *b,
                                  nir_intrinsic_instr *intrin,
                                  lower_task_state *s)
{
   if (s->payload_in_shared) {
      /* Copy the payload from shared memory.
       * Because launch_mesh_workgroups may only occur in
       * workgroup-uniform control flow, here we assume that
       * all invocations in the workgroup are active and therefore
       * they can all participate in the copy.
       *
       * TODO: Skip the copy when the mesh dispatch size is (0, 0, 0).
       *       This is problematic because the dispatch size can be divergent,
       *       and may differ across subgroups.
       */

      uint32_t payload_addr = nir_intrinsic_base(intrin);
      uint32_t payload_size = nir_intrinsic_range(intrin);

      b->cursor = nir_before_instr(&intrin->instr);
      emit_shared_to_payload_copy(b, payload_addr, payload_size, s);
   }

   /* The launch_mesh_workgroups intrinsic is a terminating instruction,
    * so let's delete everything after it.
    */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_block *current_block = nir_cursor_current_block(b->cursor);

   /* Delete following instructions in the current block. */
   nir_foreach_instr_reverse_safe(instr, current_block) {
      if (instr == &intrin->instr)
         break;
      nir_instr_remove(instr);
   }

   /* Delete following CF at the same level. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_cf_list extracted;
   nir_cf_node *end_node = &current_block->cf_node;
   while (!nir_cf_node_is_last(end_node))
      end_node = nir_cf_node_next(end_node);
   nir_cf_extract(&extracted, b->cursor, nir_after_cf_node(end_node));
   nir_cf_delete(&extracted);

   /* Terminate the task shader. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_jump(b, nir_jump_return);

   return true;
}

static bool
lower_task_intrin(nir_builder *b,
                  nir_instr *instr,
                  void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_state *s = (lower_task_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_task_payload_atomic_add:
   case nir_intrinsic_task_payload_atomic_imin:
   case nir_intrinsic_task_payload_atomic_umin:
   case nir_intrinsic_task_payload_atomic_imax:
   case nir_intrinsic_task_payload_atomic_umax:
   case nir_intrinsic_task_payload_atomic_and:
   case nir_intrinsic_task_payload_atomic_or:
   case nir_intrinsic_task_payload_atomic_xor:
   case nir_intrinsic_task_payload_atomic_exchange:
   case nir_intrinsic_task_payload_atomic_comp_swap:
   case nir_intrinsic_task_payload_atomic_fadd:
   case nir_intrinsic_task_payload_atomic_fmin:
   case nir_intrinsic_task_payload_atomic_fmax:
   case nir_intrinsic_task_payload_atomic_fcomp_swap:
   case nir_intrinsic_store_task_payload:
   case nir_intrinsic_load_task_payload:
      if (s->payload_in_shared)
         return lower_task_payload_to_shared(b, intrin, s);
      return false;
   case nir_intrinsic_launch_mesh_workgroups:
      return lower_task_launch_mesh_workgroups(b, intrin, s);
   default:
      return false;
   }
}

static bool
uses_task_payload_atomics(nir_shader *shader)
{
   nir_foreach_function(func, shader) {
      if (!func->impl)
         continue;

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_task_payload_atomic_add:
            case nir_intrinsic_task_payload_atomic_imin:
            case nir_intrinsic_task_payload_atomic_umin:
            case nir_intrinsic_task_payload_atomic_imax:
            case nir_intrinsic_task_payload_atomic_umax:
            case nir_intrinsic_task_payload_atomic_and:
            case nir_intrinsic_task_payload_atomic_or:
            case nir_intrinsic_task_payload_atomic_xor:
            case nir_intrinsic_task_payload_atomic_exchange:
            case nir_intrinsic_task_payload_atomic_comp_swap:
            case nir_intrinsic_task_payload_atomic_fadd:
            case nir_intrinsic_task_payload_atomic_fmin:
            case nir_intrinsic_task_payload_atomic_fmax:
            case nir_intrinsic_task_payload_atomic_fcomp_swap:
               return true;
            default:
               break;
            }
         }
      }
   }

   return false;
}


/**
 * Common Task Shader lowering to make the job of the backends easier.
 *
 * - Lowers NV_mesh_shader TASK_COUNT output to launch_mesh_workgroups.
 * - Removes all code after launch_mesh_workgroups, enforcing the
 *   fact that it's a terminating instruction.
 * - Ensures that task shaders always have at least one
 *   launch_mesh_workgroups instruction, so the backend doesn't
 *   need to implement a special case when the shader doesn't have it.
 * - Optionally, implements task_payload using shared memory when
 *   task_payload atomics are used.
 *   This is useful when the backend can't support the same atomic
 *   operations on the payload as it can on shared memory.
 *   When this option is used, the backend only has to implement the
 *   basic load/store operations for task_payload.
 *
 * Note: this pass operates on lowered explicit I/O intrinsics, so
 * it should be called after nir_lower_io + nir_lower_explicit_io.
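 *
 * Typical driver usage (illustrative sketch; the surrounding driver code
 * and option values are assumptions, not part of this pass):
 *
 *    nir_lower_task_shader_options opts = {
 *       .payload_to_shared_for_atomics = true,
 *    };
 *    nir_lower_task_shader(nir, opts);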
 */
bool
nir_lower_task_shader(nir_shader *shader,
                      nir_lower_task_shader_options options)
{
   if (shader->info.stage != MESA_SHADER_TASK)
      return false;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder;
   nir_builder_init(&builder, impl);

   if (shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_TASK_COUNT)) {
      /* NV_mesh_shader:
       * If the shader writes TASK_COUNT, lower that to emit
       * the new launch_mesh_workgroups intrinsic instead.
       */
      nir_lower_nv_task_count(shader);
   } else {
      /* To make sure that task shaders always have a code path that
       * executes a launch_mesh_workgroups, add one at the end.
       * If the shader already had a launch_mesh_workgroups, the one added
       * here will be removed by the dead-code cleanup below.
       */
      nir_block *last_block = nir_impl_last_block(impl);
      builder.cursor = nir_after_block_before_jump(last_block);
      nir_launch_mesh_workgroups(&builder, nir_imm_zero(&builder, 3, 32));
   }

   bool payload_in_shared = options.payload_to_shared_for_atomics &&
                            uses_task_payload_atomics(shader);

   lower_task_state state = {
      .payload_shared_addr = ALIGN(shader->info.shared_size, 16),
      .payload_in_shared = payload_in_shared,
   };

   if (payload_in_shared)
      shader->info.shared_size =
         state.payload_shared_addr + shader->info.task_payload_size;

   nir_shader_instructions_pass(shader, lower_task_intrin,
                                nir_metadata_none, &state);

   /* Delete all code that potentially can't be reached due to
    * launch_mesh_workgroups being a terminating instruction.
    */
   nir_lower_returns(shader);
   bool progress;
   do {
      progress = false;
      NIR_PASS(progress, shader, nir_opt_dead_cf);
      NIR_PASS(progress, shader, nir_opt_dce);
   } while (progress);
   return true;
}