1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "d3d12_compiler.h"
25#include "d3d12_context.h"
26#include "d3d12_debug.h"
27#include "d3d12_screen.h"
28
29#include "nir.h"
30#include "compiler/nir/nir_builder.h"
31#include "compiler/nir/nir_builtin_builder.h"
32
33#include "util/u_memory.h"
34#include "util/u_simple_shaders.h"
35
36static nir_ssa_def *
37nir_cull_face(nir_builder *b, nir_variable *vertices, bool ccw)
38{
39   nir_ssa_def *v0 =
40       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 0)));
41   nir_ssa_def *v1 =
42       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 1)));
43   nir_ssa_def *v2 =
44       nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, vertices), nir_imm_int(b, 2)));
45
46   nir_ssa_def *dir = nir_fdot(b, nir_cross4(b, nir_fsub(b, v1, v0),
47                                               nir_fsub(b, v2, v0)),
48                                   nir_imm_vec4(b, 0.0, 0.0, -1.0, 0.0));
49   if (ccw)
50       return nir_fge(b, nir_imm_int(b, 0), dir);
51   else
52       return nir_flt(b, nir_imm_int(b, 0), dir);
53}
54
55static void
56copy_vars(nir_builder *b, nir_deref_instr *dst, nir_deref_instr *src)
57{
58   assert(glsl_get_bare_type(dst->type) == glsl_get_bare_type(src->type));
59   if (glsl_type_is_struct(dst->type)) {
60      for (unsigned i = 0; i < glsl_get_length(dst->type); ++i) {
61         copy_vars(b, nir_build_deref_struct(b, dst, i), nir_build_deref_struct(b, src, i));
62      }
63   } else if (glsl_type_is_array_or_matrix(dst->type)) {
64      copy_vars(b, nir_build_deref_array_wildcard(b, dst), nir_build_deref_array_wildcard(b, src));
65   } else {
66      nir_copy_deref(b, dst, src);
67   }
68}
69
70static d3d12_shader_selector*
71d3d12_make_passthrough_gs(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
72{
73   struct d3d12_shader_selector *gs;
74   uint64_t varyings = key->varyings.mask;
75   nir_shader *nir;
76   struct pipe_shader_state templ;
77
78   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
79                                                  &d3d12_screen(ctx->base.screen)->nir_options,
80                                                  "passthrough");
81
82   nir = b.shader;
83   nir->info.inputs_read = varyings;
84   nir->info.outputs_written = varyings;
85   nir->info.gs.input_primitive = GL_POINTS;
86   nir->info.gs.output_primitive = GL_POINTS;
87   nir->info.gs.vertices_in = 1;
88   nir->info.gs.vertices_out = 1;
89   nir->info.gs.invocations = 1;
90   nir->info.gs.active_stream_mask = 1;
91
92   /* Copy inputs to outputs. */
93   while (varyings) {
94      char tmp[100];
95      const int i = u_bit_scan64(&varyings);
96
97      unsigned frac_slots = key->varyings.slots[i].location_frac_mask;
98      while (frac_slots) {
99         nir_variable *in, *out;
100         int j = u_bit_scan(&frac_slots);
101
102         snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", key->varyings.slots[i].vars[j].driver_location);
103         in = nir_variable_create(nir,
104                                  nir_var_shader_in,
105                                  glsl_array_type(key->varyings.slots[i].types[j], 1, false),
106                                  tmp);
107         in->data.location = i;
108         in->data.location_frac = j;
109         in->data.driver_location = key->varyings.slots[i].vars[j].driver_location;
110         in->data.interpolation = key->varyings.slots[i].vars[j].interpolation;
111         in->data.compact = key->varyings.slots[i].vars[j].compact;
112
113         snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", key->varyings.slots[i].vars[j].driver_location);
114         out = nir_variable_create(nir,
115                                   nir_var_shader_out,
116                                   key->varyings.slots[i].types[j],
117                                   tmp);
118         out->data.location = i;
119         out->data.location_frac = j;
120         out->data.driver_location = key->varyings.slots[i].vars[j].driver_location;
121         out->data.interpolation = key->varyings.slots[i].vars[j].interpolation;
122         out->data.compact = key->varyings.slots[i].vars[j].compact;
123
124         nir_deref_instr *in_value = nir_build_deref_array(&b, nir_build_deref_var(&b, in),
125                                                               nir_imm_int(&b, 0));
126         copy_vars(&b, nir_build_deref_var(&b, out), in_value);
127      }
128   }
129
130   nir_emit_vertex(&b, 0);
131   nir_end_primitive(&b, 0);
132
133   NIR_PASS_V(nir, nir_lower_var_copies);
134   nir_validate_shader(nir, "in d3d12_create_passthrough_gs");
135
136   templ.type = PIPE_SHADER_IR_NIR;
137   templ.ir.nir = nir;
138   templ.stream_output.num_outputs = 0;
139
140   gs = d3d12_create_shader(ctx, PIPE_SHADER_GEOMETRY, &templ);
141
142   return gs;
143}
144
/**
 * State shared between d3d12_begin_emit_primitives_gs(), the per-primitive
 * emit functions and d3d12_finish_emit_primitives_gs().
 *
 * Callers in this file zero-initialize the structure before passing it to
 * d3d12_begin_emit_primitives_gs().
 */
struct emit_primitives_context
{
   /* Context the finished shader is created for. */
   struct d3d12_context *ctx;
   /* Builder used to emit the geometry shader. */
   nir_builder b;

   /* Number of valid in[]/out[] pairs. */
   unsigned num_vars;
   /* Per-vertex input variables (arrays of three elements). */
   nir_variable *in[VARYING_SLOT_MAX];
   /* Output variable paired with in[] at the same index. */
   nir_variable *out[VARYING_SLOT_MAX];
   /* Extra output that receives front_facing; only created when the key
    * has has_front_face set. */
   nir_variable *front_facing_var;

   /* Loop opened by begin and closed by finish. */
   nir_loop *loop;
   /* Deref of the local "loop_index" counter variable. */
   nir_deref_instr *loop_index_deref;
   /* Value of the counter loaded at the top of the loop body. */
   nir_ssa_def *loop_index;
   /* Combined per-vertex condition (cull test and/or edge flag); stays NULL
    * when none of them applies. */
   nir_ssa_def *edgeflag_cmp;
   /* Facing value stored to front_facing_var; only set when the key has
    * has_front_face set. */
   nir_ssa_def *front_facing;
};
161
162static bool
163d3d12_begin_emit_primitives_gs(struct emit_primitives_context *emit_ctx,
164                               struct d3d12_context *ctx,
165                               struct d3d12_gs_variant_key *key,
166                               uint16_t output_primitive,
167                               unsigned vertices_out)
168{
169   nir_builder *b = &emit_ctx->b;
170   nir_variable *edgeflag_var = NULL;
171   nir_variable *pos_var = NULL;
172   uint64_t varyings = key->varyings.mask;
173
174   emit_ctx->ctx = ctx;
175
176   emit_ctx->b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
177                                                &d3d12_screen(ctx->base.screen)->nir_options,
178                                                "edgeflags");
179
180   nir_shader *nir = b->shader;
181   nir->info.inputs_read = varyings;
182   nir->info.outputs_written = varyings;
183   nir->info.gs.input_primitive = GL_TRIANGLES;
184   nir->info.gs.output_primitive = output_primitive;
185   nir->info.gs.vertices_in = 3;
186   nir->info.gs.vertices_out = vertices_out;
187   nir->info.gs.invocations = 1;
188   nir->info.gs.active_stream_mask = 1;
189
190   while (varyings) {
191      char tmp[100];
192      const int i = u_bit_scan64(&varyings);
193
194      unsigned frac_slots = key->varyings.slots[i].location_frac_mask;
195      while (frac_slots) {
196         int j = u_bit_scan(&frac_slots);
197         snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", emit_ctx->num_vars);
198         emit_ctx->in[emit_ctx->num_vars] = nir_variable_create(nir,
199                                                                nir_var_shader_in,
200                                                                glsl_array_type(key->varyings.slots[i].types[j], 3, 0),
201                                                                tmp);
202         emit_ctx->in[emit_ctx->num_vars]->data.location = i;
203         emit_ctx->in[emit_ctx->num_vars]->data.location_frac = j;
204         emit_ctx->in[emit_ctx->num_vars]->data.driver_location = key->varyings.slots[i].vars[j].driver_location;
205         emit_ctx->in[emit_ctx->num_vars]->data.interpolation = key->varyings.slots[i].vars[j].interpolation;
206         emit_ctx->in[emit_ctx->num_vars]->data.compact = key->varyings.slots[i].vars[j].compact;
207
208         /* Don't create an output for the edge flag variable */
209         if (i == VARYING_SLOT_EDGE) {
210            edgeflag_var = emit_ctx->in[emit_ctx->num_vars];
211            continue;
212         } else if (i == VARYING_SLOT_POS) {
213             pos_var = emit_ctx->in[emit_ctx->num_vars];
214         }
215
216         snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", emit_ctx->num_vars);
217         emit_ctx->out[emit_ctx->num_vars] = nir_variable_create(nir,
218                                                                 nir_var_shader_out,
219                                                                 key->varyings.slots[i].types[j],
220                                                                 tmp);
221         emit_ctx->out[emit_ctx->num_vars]->data.location = i;
222         emit_ctx->out[emit_ctx->num_vars]->data.location_frac = j;
223         emit_ctx->out[emit_ctx->num_vars]->data.driver_location = key->varyings.slots[i].vars[j].driver_location;
224         emit_ctx->out[emit_ctx->num_vars]->data.interpolation = key->varyings.slots[i].vars[j].interpolation;
225         emit_ctx->out[emit_ctx->num_vars]->data.compact = key->varyings.slots[i].vars[j].compact;
226
227         emit_ctx->num_vars++;
228      }
229   }
230
231   if (key->has_front_face) {
232      emit_ctx->front_facing_var = nir_variable_create(nir,
233                                                       nir_var_shader_out,
234                                                       glsl_uint_type(),
235                                                       "gl_FrontFacing");
236      emit_ctx->front_facing_var->data.location = VARYING_SLOT_VAR12;
237      emit_ctx->front_facing_var->data.driver_location = emit_ctx->num_vars;
238      emit_ctx->front_facing_var->data.interpolation = INTERP_MODE_FLAT;
239   }
240
241   /* Temporary variable "loop_index" to loop over input vertices */
242   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
243   nir_variable *loop_index_var =
244      nir_local_variable_create(impl, glsl_uint_type(), "loop_index");
245   emit_ctx->loop_index_deref = nir_build_deref_var(b, loop_index_var);
246   nir_store_deref(b, emit_ctx->loop_index_deref, nir_imm_int(b, 0), 1);
247
248   nir_ssa_def *diagonal_vertex = NULL;
249   if (key->edge_flag_fix) {
250      nir_ssa_def *prim_id = nir_load_primitive_id(b);
251      nir_ssa_def *odd = nir_build_alu(b, nir_op_imod,
252                                       prim_id,
253                                       nir_imm_int(b, 2),
254                                       NULL, NULL);
255      diagonal_vertex = nir_bcsel(b, nir_i2b(b, odd),
256                                  nir_imm_int(b, 2),
257                                  nir_imm_int(b, 1));
258   }
259
260   if (key->cull_mode != PIPE_FACE_NONE || key->has_front_face) {
261      if (key->cull_mode == PIPE_FACE_BACK)
262         emit_ctx->edgeflag_cmp = nir_cull_face(b, pos_var, key->front_ccw);
263      else if (key->cull_mode == PIPE_FACE_FRONT)
264         emit_ctx->edgeflag_cmp = nir_cull_face(b, pos_var, !key->front_ccw);
265
266      if (key->has_front_face) {
267         if (key->cull_mode == PIPE_FACE_BACK)
268            emit_ctx->front_facing = emit_ctx->edgeflag_cmp;
269         else
270            emit_ctx->front_facing = nir_cull_face(b, pos_var, key->front_ccw);
271         emit_ctx->front_facing = nir_i2i32(b, emit_ctx->front_facing);
272      }
273   }
274
275   /**
276    *  while {
277    *     if (loop_index >= 3)
278    *        break;
279    */
280   emit_ctx->loop = nir_push_loop(b);
281
282   emit_ctx->loop_index = nir_load_deref(b, emit_ctx->loop_index_deref);
283   nir_ssa_def *cmp = nir_ige(b, emit_ctx->loop_index,
284                              nir_imm_int(b, 3));
285   nir_if *loop_check = nir_push_if(b, cmp);
286   nir_jump(b, nir_jump_break);
287   nir_pop_if(b, loop_check);
288
289   if (edgeflag_var) {
290      nir_ssa_def *edge_flag =
291         nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, edgeflag_var), emit_ctx->loop_index));
292      nir_ssa_def *is_edge = nir_feq(b, nir_channel(b, edge_flag, 0), nir_imm_float(b, 1.0));
293      if (emit_ctx->edgeflag_cmp)
294         emit_ctx->edgeflag_cmp = nir_iand(b, emit_ctx->edgeflag_cmp, is_edge);
295      else
296         emit_ctx->edgeflag_cmp = is_edge;
297   }
298
299   if (key->edge_flag_fix) {
300      nir_ssa_def *is_edge = nir_ine(b, emit_ctx->loop_index, diagonal_vertex);
301      if (emit_ctx->edgeflag_cmp)
302         emit_ctx->edgeflag_cmp = nir_iand(b, emit_ctx->edgeflag_cmp, is_edge);
303      else
304         emit_ctx->edgeflag_cmp = is_edge;
305   }
306
307   return true;
308}
309
310static struct d3d12_shader_selector *
311d3d12_finish_emit_primitives_gs(struct emit_primitives_context *emit_ctx, bool end_primitive)
312{
313   struct pipe_shader_state templ;
314   nir_builder *b = &emit_ctx->b;
315   nir_shader *nir = b->shader;
316
317   /**
318    *     loop_index++;
319    *  }
320    */
321   nir_store_deref(b, emit_ctx->loop_index_deref, nir_iadd_imm(b, emit_ctx->loop_index, 1), 1);
322   nir_pop_loop(b, emit_ctx->loop);
323
324   if (end_primitive)
325      nir_end_primitive(b, 0);
326
327   nir_validate_shader(nir, "in d3d12_lower_edge_flags");
328
329   NIR_PASS_V(nir, nir_lower_var_copies);
330
331   templ.type = PIPE_SHADER_IR_NIR;
332   templ.ir.nir = nir;
333   templ.stream_output.num_outputs = 0;
334
335   return d3d12_create_shader(emit_ctx->ctx, PIPE_SHADER_GEOMETRY, &templ);
336}
337
338static d3d12_shader_selector*
339d3d12_emit_points(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
340{
341   struct emit_primitives_context emit_ctx = {0};
342   nir_builder *b = &emit_ctx.b;
343
344   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_POINTS, 3);
345
346   /**
347    *  if (edge_flag)
348    *     out_position = in_position;
349    *  else
350    *     out_position = vec4(-2.0, -2.0, 0.0, 1.0); // Invalid position
351    *
352    *  [...] // Copy other variables
353    *
354    *  EmitVertex();
355    */
356   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
357      nir_ssa_def *index = (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location))  ?
358                              nir_imm_int(b, (key->flatshade_first ? 0 : 2)) : emit_ctx.loop_index;
359      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
360      if (emit_ctx.in[i]->data.location == VARYING_SLOT_POS && emit_ctx.edgeflag_cmp) {
361         nir_if *edge_check = nir_push_if(b, emit_ctx.edgeflag_cmp);
362         copy_vars(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
363         nir_if *edge_else = nir_push_else(b, edge_check);
364         nir_store_deref(b, nir_build_deref_var(b, emit_ctx.out[i]),
365                         nir_imm_vec4(b, -2.0, -2.0, 0.0, 1.0), 0xf);
366         nir_pop_if(b, edge_else);
367      } else {
368         copy_vars(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
369      }
370   }
371   if (key->has_front_face)
372       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
373   nir_emit_vertex(b, 0);
374
375   return d3d12_finish_emit_primitives_gs(&emit_ctx, false);
376}
377
378static d3d12_shader_selector*
379d3d12_emit_lines(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
380{
381   struct emit_primitives_context emit_ctx = {0};
382   nir_builder *b = &emit_ctx.b;
383
384   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_LINE_STRIP, 6);
385
386   nir_ssa_def *next_index = nir_imod(b, nir_iadd_imm(b, emit_ctx.loop_index, 1), nir_imm_int(b, 3));
387
388   /* First vertex */
389   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
390      nir_ssa_def *index = (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location)) ?
391                              nir_imm_int(b, (key->flatshade_first ? 0 : 2)) : emit_ctx.loop_index;
392      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
393      copy_vars(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
394   }
395   if (key->has_front_face)
396       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
397   nir_emit_vertex(b, 0);
398
399   /* Second vertex. If not an edge, use same position as first vertex */
400   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
401      nir_ssa_def *index = next_index;
402      if (emit_ctx.in[i]->data.location == VARYING_SLOT_POS)
403         index = nir_bcsel(b, emit_ctx.edgeflag_cmp, next_index, emit_ctx.loop_index);
404      else if (key->flat_varyings & (1ull << emit_ctx.in[i]->data.location))
405         index = nir_imm_int(b, 2);
406      copy_vars(b, nir_build_deref_var(b, emit_ctx.out[i]),
407                nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index));
408   }
409   if (key->has_front_face)
410       nir_store_var(b, emit_ctx.front_facing_var, emit_ctx.front_facing, 0x1);
411   nir_emit_vertex(b, 0);
412
413   nir_end_primitive(b, 0);
414
415   return d3d12_finish_emit_primitives_gs(&emit_ctx, false);
416}
417
418static d3d12_shader_selector*
419d3d12_emit_triangles(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
420{
421   struct emit_primitives_context emit_ctx = {0};
422   nir_builder *b = &emit_ctx.b;
423
424   d3d12_begin_emit_primitives_gs(&emit_ctx, ctx, key, GL_TRIANGLE_STRIP, 3);
425
426   /**
427    *  [...] // Copy variables
428    *
429    *  EmitVertex();
430    */
431
432   nir_ssa_def *incr = NULL;
433
434   if (key->provoking_vertex > 0)
435      incr = nir_imm_int(b, key->provoking_vertex);
436   else
437      incr = nir_imm_int(b, 3);
438
439   if (key->alternate_tri) {
440      nir_ssa_def *odd = nir_imod(b, nir_load_primitive_id(b), nir_imm_int(b, 2));
441      incr = nir_isub(b, incr, odd);
442   }
443
444   assert(incr != NULL);
445   nir_ssa_def *index = nir_imod(b, nir_iadd(b, emit_ctx.loop_index, incr), nir_imm_int(b, 3));
446   for (unsigned i = 0; i < emit_ctx.num_vars; ++i) {
447      nir_deref_instr *in_value = nir_build_deref_array(b, nir_build_deref_var(b, emit_ctx.in[i]), index);
448      copy_vars(b, nir_build_deref_var(b, emit_ctx.out[i]), in_value);
449   }
450   nir_emit_vertex(b, 0);
451
452   return d3d12_finish_emit_primitives_gs(&emit_ctx, true);
453}
454
455static uint32_t
456hash_gs_variant_key(const void *key)
457{
458   return _mesa_hash_data(key, sizeof(struct d3d12_gs_variant_key));
459}
460
461static bool
462equals_gs_variant_key(const void *a, const void *b)
463{
464   return memcmp(a, b, sizeof(struct d3d12_gs_variant_key)) == 0;
465}
466
467void
468d3d12_gs_variant_cache_init(struct d3d12_context *ctx)
469{
470   ctx->gs_variant_cache = _mesa_hash_table_create(NULL, NULL, equals_gs_variant_key);
471}
472
473static void
474delete_entry(struct hash_entry *entry)
475{
476   d3d12_shader_free((d3d12_shader_selector *)entry->data);
477}
478
479void
480d3d12_gs_variant_cache_destroy(struct d3d12_context *ctx)
481{
482   _mesa_hash_table_destroy(ctx->gs_variant_cache, delete_entry);
483}
484
485static struct d3d12_shader_selector *
486create_geometry_shader_variant(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
487{
488   d3d12_shader_selector *gs = NULL;
489
490   if (key->passthrough)
491      gs = d3d12_make_passthrough_gs(ctx, key);
492   else if (key->provoking_vertex > 0 || key->alternate_tri)
493      gs = d3d12_emit_triangles(ctx, key);
494   else if (key->fill_mode == PIPE_POLYGON_MODE_POINT)
495      gs = d3d12_emit_points(ctx, key);
496   else if (key->fill_mode == PIPE_POLYGON_MODE_LINE)
497      gs = d3d12_emit_lines(ctx, key);
498
499   if (gs) {
500      gs->is_variant = true;
501      gs->gs_key = *key;
502   }
503
504   return gs;
505}
506
507d3d12_shader_selector *
508d3d12_get_gs_variant(struct d3d12_context *ctx, struct d3d12_gs_variant_key *key)
509{
510   uint32_t hash = hash_gs_variant_key(key);
511   struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->gs_variant_cache,
512                                                                 hash, key);
513   if (!entry) {
514      d3d12_shader_selector *gs = create_geometry_shader_variant(ctx, key);
515      entry = _mesa_hash_table_insert_pre_hashed(ctx->gs_variant_cache,
516                                                 hash, &gs->gs_key, gs);
517      assert(entry);
518   }
519
520   return (d3d12_shader_selector *)entry->data;
521}
522