/*
 * Copyright © 2022 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Timur Kristóf
 *
 */

#include "nir.h"
#include "nir_builder.h"
#include "util/u_math.h"

typedef struct {
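   /* Shared memory address where the TASK_COUNT output is stored. */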
   uint32_t task_count_shared_addr;
} lower_task_nv_state;

typedef struct {
   /* If true, lower all task_payload I/O to use shared memory. */
   bool payload_in_shared;
   /* Shared memory address where task_payload will be located. */
   uint32_t payload_shared_addr;
} lower_task_state;

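/* Lower NV_mesh_shader TASK_COUNT output loads/stores to shared memory
 * accesses at task_count_shared_addr.
 */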
static bool
lower_nv_task_output(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_nv_state *s = (lower_task_nv_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_output: {
      b->cursor = nir_after_instr(instr);
      nir_ssa_def *load =
         nir_load_shared(b, 1, 32, nir_imm_int(b, s->task_count_shared_addr));
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, load);
      nir_instr_remove(instr);
      return true;
   }

   case nir_intrinsic_store_output: {
      b->cursor = nir_after_instr(instr);
      nir_ssa_def *store_val = intrin->src[0].ssa;
      nir_store_shared(b, store_val, nir_imm_int(b, s->task_count_shared_addr));
      nir_instr_remove(instr);
      return true;
   }

   default:
      return false;
   }
}

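/* Initialize the task count in shared memory at the start of the shader,
 * then read it back at the end and emit launch_mesh_workgroups with it.
 */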
static void
append_launch_mesh_workgroups_to_nv_task(nir_builder *b,
                                         lower_task_nv_state *s)
{
   /* At the beginning of the shader, write 0 to the task count.
    * This ensures that 0 mesh workgroups are launched when the
    * shader doesn't write the TASK_COUNT output.
    */
   b->cursor = nir_before_cf_list(&b->impl->body);
   nir_ssa_def *zero = nir_imm_int(b, 0);
   nir_store_shared(b, zero, nir_imm_int(b, s->task_count_shared_addr));

   nir_scoped_barrier(b,
         .execution_scope = NIR_SCOPE_WORKGROUP,
         .memory_scope = NIR_SCOPE_WORKGROUP,
         .memory_semantics = NIR_MEMORY_RELEASE,
         .memory_modes = nir_var_mem_shared);

   /* At the end of the shader, read the task count from shared memory
    * and emit launch_mesh_workgroups.
    */
   b->cursor = nir_after_cf_list(&b->impl->body);

   nir_scoped_barrier(b,
         .execution_scope = NIR_SCOPE_WORKGROUP,
         .memory_scope = NIR_SCOPE_WORKGROUP,
         .memory_semantics = NIR_MEMORY_ACQUIRE,
         .memory_modes = nir_var_mem_shared);

   nir_ssa_def *task_count =
      nir_load_shared(b, 1, 32, nir_imm_int(b, s->task_count_shared_addr));
   /* NV_mesh_shader doesn't offer a way to choose which task_payload
    * variables should be passed to mesh shaders, so we just pass them all.
    */
   uint32_t range = b->shader->info.task_payload_size;

   nir_ssa_def *one = nir_imm_int(b, 1);
   nir_ssa_def *dispatch_3d = nir_vec3(b, task_count, one, one);
   nir_launch_mesh_workgroups(b, dispatch_3d, .base = 0, .range = range);
}

/**
 * For NV_mesh_shader:
 * Task shaders have only one output, TASK_COUNT, which is a 32-bit
 * unsigned integer containing the 1-dimensional mesh dispatch size.
 * This output should behave like a shared variable.
 *
 * We lower this output to a shared variable and then we emit
 * the new launch_mesh_workgroups intrinsic at the end of the shader.
 */
static void
nir_lower_nv_task_count(nir_shader *shader)
{
   lower_task_nv_state state = {
      .task_count_shared_addr = ALIGN(shader->info.shared_size, 4),
   };

   shader->info.shared_size += 4;
   nir_shader_instructions_pass(shader, lower_nv_task_output,
                                nir_metadata_none, &state);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder;
   nir_builder_init(&builder, impl);

   append_launch_mesh_workgroups_to_nv_task(&builder, &state);
   nir_metadata_preserve(impl, nir_metadata_none);
}

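/* Return the shared memory opcode that corresponds to the given
 * task_payload opcode.
 */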
static nir_intrinsic_op
shared_opcode_for_task_payload(nir_intrinsic_op task_payload_op)
{
   switch (task_payload_op) {
#define OP(O) case nir_intrinsic_task_payload_##O: return nir_intrinsic_shared_##O;
   OP(atomic_exchange)
   OP(atomic_comp_swap)
   OP(atomic_add)
   OP(atomic_imin)
   OP(atomic_umin)
   OP(atomic_imax)
   OP(atomic_umax)
   OP(atomic_and)
   OP(atomic_or)
   OP(atomic_xor)
   OP(atomic_fadd)
   OP(atomic_fmin)
   OP(atomic_fmax)
   OP(atomic_fcomp_swap)
#undef OP
   case nir_intrinsic_load_task_payload:
      return nir_intrinsic_load_shared;
   case nir_intrinsic_store_task_payload:
      return nir_intrinsic_store_shared;
   default:
      unreachable("Invalid task payload atomic");
   }
}

static bool
lower_task_payload_to_shared(nir_builder *b,
                             nir_intrinsic_instr *intrin,
                             lower_task_state *s)
{
   /* This assumes that shared and task_payload intrinsics
    * have the same number of sources and same indices.
    */
   unsigned base = nir_intrinsic_base(intrin);
   intrin->intrinsic = shared_opcode_for_task_payload(intrin->intrinsic);
   nir_intrinsic_set_base(intrin, base + s->payload_shared_addr);

   return true;
}

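/* Copy the task payload from shared memory to the real task_payload storage,
 * with all invocations in the workgroup cooperating; each invocation copies
 * a vec4 (16 bytes) per loop iteration.
 */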
static void
emit_shared_to_payload_copy(nir_builder *b,
                            uint32_t payload_addr,
                            uint32_t payload_size,
                            lower_task_state *s)
{
   const unsigned invocations = b->shader->info.workgroup_size[0] *
                                b->shader->info.workgroup_size[1] *
                                b->shader->info.workgroup_size[2];
   const unsigned bytes_per_copy = 16;
   const unsigned copies_needed = DIV_ROUND_UP(payload_size, bytes_per_copy);
   const unsigned copies_per_invocation = DIV_ROUND_UP(copies_needed, invocations);
   const unsigned base_shared_addr = s->payload_shared_addr + payload_addr;

   nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
   nir_ssa_def *addr = nir_imul_imm(b, invocation_index, bytes_per_copy);

   /* Wait for all previous shared stores to finish.
    * This is necessary because we placed the payload in shared memory.
    */
   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
                         .memory_scope = NIR_SCOPE_WORKGROUP,
                         .memory_semantics = NIR_MEMORY_ACQ_REL,
                         .memory_modes = nir_var_mem_shared);

   for (unsigned i = 0; i < copies_per_invocation; ++i) {
      unsigned const_off = bytes_per_copy * invocations * i;

      /* Read from shared memory. */
      nir_ssa_def *copy =
         nir_load_shared(b, 4, 32, addr, .align_mul = 16,
                         .base = base_shared_addr + const_off);

      /* Write to task payload memory. */
      nir_store_task_payload(b, copy, addr, .base = const_off);
   }
}

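/* Lower launch_mesh_workgroups: copy the payload out of shared memory first
 * (if it lives there), then delete everything after the intrinsic and
 * terminate the shader, because launch_mesh_workgroups is a terminating
 * instruction.
 */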
static bool
lower_task_launch_mesh_workgroups(nir_builder *b,
                                  nir_intrinsic_instr *intrin,
                                  lower_task_state *s)
{
   if (s->payload_in_shared) {
      /* Copy the payload from shared memory.
       * Because launch_mesh_workgroups may only occur in
       * workgroup-uniform control flow, here we assume that
       * all invocations in the workgroup are active and therefore
       * they can all participate in the copy.
       *
       * TODO: Skip the copy when the mesh dispatch size is (0, 0, 0).
       *       This is problematic because the dispatch size can be divergent,
       *       and may differ across subgroups.
       */

      uint32_t payload_addr = nir_intrinsic_base(intrin);
      uint32_t payload_size = nir_intrinsic_range(intrin);

      b->cursor = nir_before_instr(&intrin->instr);
      emit_shared_to_payload_copy(b, payload_addr, payload_size, s);
   }

   /* The launch_mesh_workgroups intrinsic is a terminating instruction,
    * so let's delete everything after it.
    */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_block *current_block = nir_cursor_current_block(b->cursor);

   /* Delete following instructions in the current block. */
   nir_foreach_instr_reverse_safe(instr, current_block) {
      if (instr == &intrin->instr)
         break;
      nir_instr_remove(instr);
   }

   /* Delete following CF at the same level. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_cf_list extracted;
   nir_cf_node *end_node = &current_block->cf_node;
   while (!nir_cf_node_is_last(end_node))
      end_node = nir_cf_node_next(end_node);
   nir_cf_extract(&extracted, b->cursor, nir_after_cf_node(end_node));
   nir_cf_delete(&extracted);

   /* Terminate the task shader. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_jump(b, nir_jump_return);

   return true;
}

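/* Dispatch the per-intrinsic lowering: rewrite task_payload intrinsics to
 * shared memory when requested, and lower launch_mesh_workgroups.
 */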
static bool
lower_task_intrin(nir_builder *b,
                  nir_instr *instr,
                  void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_state *s = (lower_task_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_task_payload_atomic_add:
   case nir_intrinsic_task_payload_atomic_imin:
   case nir_intrinsic_task_payload_atomic_umin:
   case nir_intrinsic_task_payload_atomic_imax:
   case nir_intrinsic_task_payload_atomic_umax:
   case nir_intrinsic_task_payload_atomic_and:
   case nir_intrinsic_task_payload_atomic_or:
   case nir_intrinsic_task_payload_atomic_xor:
   case nir_intrinsic_task_payload_atomic_exchange:
   case nir_intrinsic_task_payload_atomic_comp_swap:
   case nir_intrinsic_task_payload_atomic_fadd:
   case nir_intrinsic_task_payload_atomic_fmin:
   case nir_intrinsic_task_payload_atomic_fmax:
   case nir_intrinsic_task_payload_atomic_fcomp_swap:
   case nir_intrinsic_store_task_payload:
   case nir_intrinsic_load_task_payload:
      if (s->payload_in_shared)
         return lower_task_payload_to_shared(b, intrin, s);
      return false;
   case nir_intrinsic_launch_mesh_workgroups:
      return lower_task_launch_mesh_workgroups(b, intrin, s);
   default:
      return false;
   }
}

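/* Whether the shader contains any task_payload atomic intrinsic. */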
static bool
uses_task_payload_atomics(nir_shader *shader)
{
   nir_foreach_function(func, shader) {
      if (!func->impl)
         continue;

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
               case nir_intrinsic_task_payload_atomic_add:
               case nir_intrinsic_task_payload_atomic_imin:
               case nir_intrinsic_task_payload_atomic_umin:
               case nir_intrinsic_task_payload_atomic_imax:
               case nir_intrinsic_task_payload_atomic_umax:
               case nir_intrinsic_task_payload_atomic_and:
               case nir_intrinsic_task_payload_atomic_or:
               case nir_intrinsic_task_payload_atomic_xor:
               case nir_intrinsic_task_payload_atomic_exchange:
               case nir_intrinsic_task_payload_atomic_comp_swap:
               case nir_intrinsic_task_payload_atomic_fadd:
               case nir_intrinsic_task_payload_atomic_fmin:
               case nir_intrinsic_task_payload_atomic_fmax:
               case nir_intrinsic_task_payload_atomic_fcomp_swap:
                  return true;
               default:
                  break;
            }
         }
      }
   }

   return false;
}

/**
 * Common Task Shader lowering to make the job of the backends easier.
 *
 * - Lowers NV_mesh_shader TASK_COUNT output to launch_mesh_workgroups.
 * - Removes all code after launch_mesh_workgroups, enforcing the
 *   fact that it's a terminating instruction.
 * - Ensures that task shaders always have at least one
 *   launch_mesh_workgroups instruction, so the backend doesn't
 *   need to implement a special case when the shader doesn't have it.
 * - Optionally, implements task_payload using shared memory when
 *   task_payload atomics are used.
 *   This is useful when the backend can't otherwise support the same
 *   atomic operations on task_payload memory as it can on shared memory.
 *   If this is used, the backend only has to implement the basic
 *   load/store operations for task_payload.
 *
 * Note: this pass operates on lowered explicit I/O intrinsics, so
 * it should be called after nir_lower_io + nir_lower_explicit_io.
 */
bool
nir_lower_task_shader(nir_shader *shader,
                      nir_lower_task_shader_options options)
{
   if (shader->info.stage != MESA_SHADER_TASK)
      return false;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder;
   nir_builder_init(&builder, impl);

   if (shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_TASK_COUNT)) {
      /* NV_mesh_shader:
       * If the shader writes TASK_COUNT, lower that to emit
       * the new launch_mesh_workgroups intrinsic instead.
       */
      nir_lower_nv_task_count(shader);
   } else {
      /* To make sure that task shaders always have a code path that
       * executes a launch_mesh_workgroups, let's add one at the end.
       * If the shader already had a launch_mesh_workgroups, the one we
       * add here will be deleted by the lowering below, because all code
       * after a launch_mesh_workgroups is removed.
       */
      nir_block *last_block = nir_impl_last_block(impl);
      builder.cursor = nir_after_block_before_jump(last_block);
      nir_launch_mesh_workgroups(&builder, nir_imm_zero(&builder, 3, 32));
   }

   bool payload_in_shared = options.payload_to_shared_for_atomics &&
                            uses_task_payload_atomics(shader);

   lower_task_state state = {
      .payload_shared_addr = ALIGN(shader->info.shared_size, 16),
      .payload_in_shared = payload_in_shared,
   };

   if (payload_in_shared)
      shader->info.shared_size =
         state.payload_shared_addr + shader->info.task_payload_size;

   nir_shader_instructions_pass(shader, lower_task_intrin,
                                nir_metadata_none, &state);

   /* Delete all code that potentially can't be reached due to
    * launch_mesh_workgroups being a terminating instruction.
    */
   nir_lower_returns(shader);
   bool progress;
   do {
      progress = false;
      NIR_PASS(progress, shader, nir_opt_dead_cf);
      NIR_PASS(progress, shader, nir_opt_dce);
   } while (progress);
   return true;
}