1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "d3d12_compiler.h"
25#include "d3d12_context.h"
26#include "d3d12_debug.h"
27#include "d3d12_screen.h"
28#include "d3d12_nir_passes.h"
29#include "nir_to_dxil.h"
30#include "dxil_nir.h"
31#include "dxil_nir_lower_int_cubemaps.h"
32
33#include "pipe/p_state.h"
34
35#include "nir.h"
36#include "nir/nir_draw_helpers.h"
37#include "nir/tgsi_to_nir.h"
38#include "compiler/nir/nir_builder.h"
39#include "tgsi/tgsi_from_mesa.h"
40#include "tgsi/tgsi_ureg.h"
41
42#include "util/hash_table.h"
43#include "util/u_memory.h"
44#include "util/u_prim.h"
45#include "util/u_simple_shaders.h"
46#include "util/u_dl.h"
47
48#include <dxguids/dxguids.h>
49
50extern "C" {
51#include "tgsi/tgsi_parse.h"
52#include "tgsi/tgsi_point_sprite.h"
53}
54
55#ifdef _WIN32
56#include "dxil_validator.h"
57#endif
58
59const void *
60d3d12_get_compiler_options(struct pipe_screen *screen,
61                           enum pipe_shader_ir ir,
62                           enum pipe_shader_type shader)
63{
64   assert(ir == PIPE_SHADER_IR_NIR);
65   return &d3d12_screen(screen)->nir_options;
66}
67
68static uint32_t
69resource_dimension(enum glsl_sampler_dim dim)
70{
71   switch (dim) {
72   case GLSL_SAMPLER_DIM_1D:
73      return RESOURCE_DIMENSION_TEXTURE1D;
74   case GLSL_SAMPLER_DIM_2D:
75      return RESOURCE_DIMENSION_TEXTURE2D;
76   case GLSL_SAMPLER_DIM_3D:
77      return RESOURCE_DIMENSION_TEXTURE3D;
78   case GLSL_SAMPLER_DIM_CUBE:
79      return RESOURCE_DIMENSION_TEXTURECUBE;
80   default:
81      return RESOURCE_DIMENSION_UNKNOWN;
82   }
83}
84
85static bool
86can_remove_dead_sampler(nir_variable *var, void *data)
87{
88   const struct glsl_type *base_type = glsl_without_array(var->type);
89   return glsl_type_is_sampler(base_type) && !glsl_type_is_bare_sampler(base_type);
90}
91
/* Compile one variant of a shader selector: run the key-dependent NIR
 * lowering passes, translate the result to DXIL, and collect the
 * resource-binding metadata needed to build descriptor tables later.
 * The returned d3d12_shader is ralloc'd onto `sel` (so it is freed with
 * the selector) and is installed as sel->current.  Returns NULL when the
 * NIR-to-DXIL translation fails. */
static struct d3d12_shader *
compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel,
            struct d3d12_shader_key *key, struct nir_shader *nir)
{
   struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
   struct d3d12_shader *shader = rzalloc(sel, d3d12_shader);
   shader->key = *key;
   shader->nir = nir;
   sel->current = shader;

   /* Split combined texture/sampler variables — DXIL binds SRVs and
    * samplers through separate register spaces. */
   NIR_PASS_V(nir, nir_lower_samplers);
   NIR_PASS_V(nir, dxil_nir_split_typed_samplers);

   /* DCE first, then drop dead uniforms; can_remove_dead_sampler keeps
    * bare samplers alive even when otherwise unused. */
   NIR_PASS_V(nir, nir_opt_dce);
   struct nir_remove_dead_variables_options dead_var_opts = {};
   dead_var_opts.can_remove_var = can_remove_dead_sampler;
   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform, &dead_var_opts);

   /* Integer textures are emulated with txf; needs the wrap/swizzle state
    * from the key plus the screen's max LOD bias. */
   if (key->samples_int_textures)
      NIR_PASS_V(nir, dxil_lower_sample_to_txf_for_integer_tex,
                 key->tex_wrap_states, key->swizzle_state,
                 screen->base.get_paramf(&screen->base, PIPE_CAPF_MAX_TEXTURE_LOD_BIAS));

   if (key->vs.needs_format_emulation)
      dxil_nir_lower_vs_vertex_conversion(nir, key->vs.format_conversion);

   /* Detect whether packing loose uniforms created a new UBO 0; that
    * decides if a default constant buffer is expected at binding 0. */
   uint32_t num_ubos_before_lower_to_ubo = nir->info.num_ubos;
   uint32_t num_uniforms_before_lower_to_ubo = nir->num_uniforms;
   NIR_PASS_V(nir, nir_lower_uniforms_to_ubo, false, false);
   shader->has_default_ubo0 = num_uniforms_before_lower_to_ubo > 0 &&
                              nir->info.num_ubos > num_ubos_before_lower_to_ubo;

   /* Position fix-ups (depth inversion, [-1,1]→[0,1] clip z, y-flip)
    * are applied only in the last stage that processes vertices. */
   if (key->last_vertex_processing_stage) {
      if (key->invert_depth)
         NIR_PASS_V(nir, d3d12_nir_invert_depth, key->invert_depth, key->halfz);
      if (!key->halfz)
         NIR_PASS_V(nir, nir_lower_clip_halfz);
      NIR_PASS_V(nir, d3d12_lower_yflip);
   }
   NIR_PASS_V(nir, nir_lower_packed_ubo_loads);
   NIR_PASS_V(nir, d3d12_lower_load_draw_params);
   NIR_PASS_V(nir, d3d12_lower_load_patch_vertices_in);
   NIR_PASS_V(nir, d3d12_lower_state_vars, shader);
   NIR_PASS_V(nir, dxil_nir_lower_bool_input);
   NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
   NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
   NIR_PASS_V(nir, dxil_nir_lower_double_math);

   if (key->fs.multisample_disabled)
      NIR_PASS_V(nir, d3d12_disable_multisampling);

   struct nir_to_dxil_options opts = {};
   opts.interpolate_at_vertex = screen->have_load_at_vertex;
   opts.lower_int16 = !screen->opts4.Native16BitShaderOpsSupported;
   opts.no_ubo0 = !shader->has_default_ubo0;
   /* The state-vars UBO is appended last and is not an array of structs. */
   opts.last_ubo_is_not_arrayed = shader->num_state_vars > 0;
   opts.provoking_vertex = key->fs.provoking_vertex;
   opts.input_clip_size = key->input_clip_size;
   opts.environment = DXIL_ENVIRONMENT_GL;
   opts.shader_model_max = SHADER_MODEL_6_2;
#ifdef _WIN32
   opts.validator_version_max = dxil_get_validator_version(ctx->dxil_validator);
#endif

   struct blob tmp;
   if (!nir_to_dxil(nir, &opts, &tmp)) {
      debug_printf("D3D12: nir_to_dxil failed\n");
      return NULL;
   }

   // Non-ubo variables
   /* Record each texture's SRV dimension and track the contiguous
    * [begin, end) SRV binding range this shader uses. */
   shader->begin_srv_binding = (UINT_MAX);
   nir_foreach_variable_with_modes(var, nir, nir_var_uniform) {
      auto type_no_array = glsl_without_array(var->type);
      if (glsl_type_is_texture(type_no_array)) {
         unsigned count = glsl_type_is_array(var->type) ? glsl_get_aoa_size(var->type) : 1;
         for (unsigned i = 0; i < count; ++i) {
            shader->srv_bindings[var->data.binding + i].dimension = resource_dimension(glsl_get_sampler_dim(type_no_array));
         }
         shader->begin_srv_binding = MIN2(var->data.binding, shader->begin_srv_binding);
         shader->end_srv_binding = MAX2(var->data.binding + count, shader->end_srv_binding);
      }
   }

   /* Image (UAV) bindings: remember format and dimension per slot. */
   nir_foreach_image_variable(var, nir) {
      auto type_no_array = glsl_without_array(var->type);
      unsigned count = glsl_type_is_array(var->type) ? glsl_get_aoa_size(var->type) : 1;
      for (unsigned i = 0; i < count; ++i) {
         shader->uav_bindings[var->data.driver_location + i].format = var->data.image.format;
         shader->uav_bindings[var->data.driver_location + i].dimension = resource_dimension(glsl_get_sampler_dim(type_no_array));
      }
   }

   // Ubo variables
   if(nir->info.num_ubos) {
      // Ignore state_vars ubo as it is bound as root constants
      unsigned num_ubo_bindings = nir->info.num_ubos - (shader->state_vars_used ? 1 : 0);
      /* Binding 0 only exists when loose uniforms were packed into it. */
      for(unsigned i = shader->has_default_ubo0 ? 0 : 1; i < num_ubo_bindings; ++i) {
         shader->cb_bindings[shader->num_cb_bindings++].binding = i;
      }
   }

#ifdef _WIN32
   /* Validate the produced module with the external DXIL validator
    * (skipped when experimental shaders are enabled) and optionally dump
    * a disassembly.  Validation failure only logs — the shader is still
    * returned. */
   if (ctx->dxil_validator) {
      if (!(d3d12_debug & D3D12_DEBUG_EXPERIMENTAL)) {
         char *err;
         if (!dxil_validate_module(ctx->dxil_validator, tmp.data,
                                   tmp.size, &err) && err) {
            debug_printf(
               "== VALIDATION ERROR =============================================\n"
               "%s\n"
               "== END ==========================================================\n",
               err);
            ralloc_free(err);
         }
      }

      if (d3d12_debug & D3D12_DEBUG_DISASS) {
         char *str = dxil_disasm_module(ctx->dxil_validator, tmp.data,
                                        tmp.size);
         fprintf(stderr,
                 "== BEGIN SHADER ============================================\n"
                 "%s\n"
                 "== END SHADER ==============================================\n",
               str);
         ralloc_free(str);
      }
   }
#endif

   /* Take ownership of the blob's buffer as the final bytecode. */
   blob_finish_get_buffer(&tmp, &shader->bytecode, &shader->bytecode_length);

   /* Debug aid: write each compiled module to dumpNN.dxil in cwd. */
   if (d3d12_debug & D3D12_DEBUG_DXIL) {
      char buf[256];
      static int i;
      snprintf(buf, sizeof(buf), "dump%02d.dxil", i++);
      FILE *fp = fopen(buf, "wb");
      fwrite(shader->bytecode, sizeof(char), shader->bytecode_length, fp);
      fclose(fp);
      fprintf(stderr, "wrote '%s'...\n", buf);
   }
   return shader;
}
235
236struct d3d12_selection_context {
237   struct d3d12_context *ctx;
238   bool needs_point_sprite_lowering;
239   bool needs_vertex_reordering;
240   unsigned provoking_vertex;
241   bool alternate_tri;
242   unsigned fill_mode_lowered;
243   unsigned cull_mode_lowered;
244   bool manual_depth_range;
245   unsigned missing_dual_src_outputs;
246   unsigned frag_result_color_lowering;
247   const unsigned *variable_workgroup_size;
248};
249
/* Return a bitmask (bits 0 and 1) of the dual-source blend output indices
 * that the bound fragment shader never writes; 0 when dual-source blending
 * is disabled or both indices are written.  The caller uses the mask to
 * have the FS variant emit dummy writes for the missing outputs. */
static unsigned
missing_dual_src_outputs(struct d3d12_context *ctx)
{
   if (!ctx->gfx_pipeline_state.blend->is_dual_src)
      return 0;

   struct d3d12_shader_selector *fs = ctx->gfx_stages[PIPE_SHADER_FRAGMENT];
   nir_shader *s = fs->initial;

   /* Scan every store to a color output and record which dual-source
    * index (0 or 1) it targets. */
   unsigned indices_seen = 0;
   nir_foreach_function(function, s) {
      if (function->impl) {
         nir_foreach_block(block, function->impl) {
            nir_foreach_instr(instr, block) {
               if (instr->type != nir_instr_type_intrinsic)
                  continue;

               nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
               if (intr->intrinsic != nir_intrinsic_store_deref)
                  continue;

               nir_variable *var = nir_intrinsic_get_var(intr, 0);
               if (var->data.mode != nir_var_shader_out)
                  continue;

               /* For FRAG_RESULT_COLOR/DATA0 the dual-source index comes
                * from var->data.index; slots above DATA0 map by location
                * offset.  Everything else (depth, stencil, ...) is skipped. */
               unsigned index = var->data.index;
               if (var->data.location > FRAG_RESULT_DATA0)
                  index = var->data.location - FRAG_RESULT_DATA0;
               else if (var->data.location != FRAG_RESULT_COLOR &&
                        var->data.location != FRAG_RESULT_DATA0)
                  continue;

               indices_seen |= 1u << index;
               /* Both indices written — nothing is missing, stop early. */
               if ((indices_seen & 3) == 3)
                  return 0;
            }
         }
      }
   }

   return 3 & ~indices_seen;
}
292
293static unsigned
294frag_result_color_lowering(struct d3d12_context *ctx)
295{
296   struct d3d12_shader_selector *fs = ctx->gfx_stages[PIPE_SHADER_FRAGMENT];
297   assert(fs);
298
299   if (fs->initial->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_COLOR))
300      return ctx->fb.nr_cbufs > 1 ? ctx->fb.nr_cbufs : 0;
301
302   return 0;
303}
304
305static bool
306manual_depth_range(struct d3d12_context *ctx)
307{
308   if (!d3d12_need_zero_one_depth_range(ctx))
309      return false;
310
311   /**
312    * If we can't use the D3D12 zero-one depth-range, we might have to apply
313    * depth-range ourselves.
314    *
315    * Because we only need to override the depth-range to zero-one range in
316    * the case where we write frag-depth, we only need to apply manual
317    * depth-range to gl_FragCoord.z.
318    *
319    * No extra care is needed to be taken in the case where gl_FragDepth is
320    * written conditionally, because the GLSL 4.60 spec states:
321    *
322    *    If a shader statically assigns a value to gl_FragDepth, and there
323    *    is an execution path through the shader that does not set
324    *    gl_FragDepth, then the value of the fragment’s depth may be
325    *    undefined for executions of the shader that take that path. That
326    *    is, if the set of linked fragment shaders statically contain a
327    *    write to gl_FragDepth, then it is responsible for always writing
328    *    it.
329    */
330
331   struct d3d12_shader_selector *fs = ctx->gfx_stages[PIPE_SHADER_FRAGMENT];
332   return fs && fs->initial->info.inputs_read & VARYING_BIT_POS;
333}
334
335static bool
336needs_edge_flag_fix(enum pipe_prim_type mode)
337{
338   return (mode == PIPE_PRIM_QUADS ||
339           mode == PIPE_PRIM_QUAD_STRIP ||
340           mode == PIPE_PRIM_POLYGON);
341}
342
/* Determine which polygon fill mode (line/point) must be emulated with a
 * geometry-shader variant for this draw.  Returns PIPE_POLYGON_MODE_FILL
 * when no lowering is needed (user GS bound, no rasterizer state, or the
 * draw isn't triangles/triangle strips). */
static unsigned
fill_mode_lowered(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo)
{
   struct d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX];

   if ((ctx->gfx_stages[PIPE_SHADER_GEOMETRY] != NULL &&
        !ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->is_variant) ||
       ctx->gfx_pipeline_state.rast == NULL ||
       (dinfo->mode != PIPE_PRIM_TRIANGLES &&
        dinfo->mode != PIPE_PRIM_TRIANGLE_STRIP))
      return PIPE_POLYGON_MODE_FILL;

   /* D3D12 supports line mode (wireframe) but doesn't support edge flags */
   /* Line lowering is needed only when the fill mode applies to a face
    * that isn't culled AND either the VS writes edge flags or the API
    * primitive requires the edge-flag fix (quads/polygons). */
   if (((ctx->gfx_pipeline_state.rast->base.fill_front == PIPE_POLYGON_MODE_LINE &&
         ctx->gfx_pipeline_state.rast->base.cull_face != PIPE_FACE_FRONT) ||
        (ctx->gfx_pipeline_state.rast->base.fill_back == PIPE_POLYGON_MODE_LINE &&
         ctx->gfx_pipeline_state.rast->base.cull_face == PIPE_FACE_FRONT)) &&
       (vs->initial->info.outputs_written & VARYING_BIT_EDGE ||
        needs_edge_flag_fix(ctx->initial_api_prim)))
      return PIPE_POLYGON_MODE_LINE;

   /* Point fill mode has no D3D12 equivalent at all — always lower.
    * NOTE(review): only fill_front is checked here; presumably back-face
    * point fill is handled the same way — confirm against callers. */
   if (ctx->gfx_pipeline_state.rast->base.fill_front == PIPE_POLYGON_MODE_POINT)
      return PIPE_POLYGON_MODE_POINT;

   return PIPE_POLYGON_MODE_FILL;
}
369
370static bool
371has_stream_out_for_streams(struct d3d12_context *ctx)
372{
373   unsigned mask = ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->initial->info.gs.active_stream_mask & ~1;
374   for (unsigned i = 0; i < ctx->gfx_pipeline_state.so_info.num_outputs; ++i) {
375      unsigned stream = ctx->gfx_pipeline_state.so_info.output[i].stream;
376      if (((1 << stream) & mask) &&
377         ctx->so_buffer_views[stream].SizeInBytes)
378         return true;
379   }
380   return false;
381}
382
/* Decide whether point-sprite/wide-point emulation via a GS variant is
 * required for this draw.  D3D12 has no wide points or point sprites, so
 * they are expanded to quads in a generated geometry shader. */
static bool
needs_point_sprite_lowering(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo)
{
   struct d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX];
   struct d3d12_shader_selector *gs = ctx->gfx_stages[PIPE_SHADER_GEOMETRY];

   if (gs != NULL && !gs->is_variant) {
      /* There is an user GS; Check if it outputs points with PSIZE */
      /* Lowering a multi-stream GS is only possible when no non-default
       * stream is captured by stream-output. */
      return (gs->initial->info.gs.output_primitive == GL_POINTS &&
              (gs->initial->info.outputs_written & VARYING_BIT_PSIZ ||
                 ctx->gfx_pipeline_state.rast->base.point_size > 1.0) &&
              (gs->initial->info.gs.active_stream_mask == 1 ||
                 !has_stream_out_for_streams(ctx)));
   } else {
      /* No user GS; check if we are drawing wide points */
      /* Wide points come from a fixed point size > 1, point offset, or a
       * per-vertex PSIZE written by the VS; the VS must also write a
       * position for the expansion to work. */
      return ((dinfo->mode == PIPE_PRIM_POINTS ||
               fill_mode_lowered(ctx, dinfo) == PIPE_POLYGON_MODE_POINT) &&
              (ctx->gfx_pipeline_state.rast->base.point_size > 1.0 ||
               ctx->gfx_pipeline_state.rast->base.offset_point ||
               (ctx->gfx_pipeline_state.rast->base.point_size_per_vertex &&
                vs->initial->info.outputs_written & VARYING_BIT_PSIZ)) &&
              (vs->initial->info.outputs_written & VARYING_BIT_POS));
   }
}
407
408static unsigned
409cull_mode_lowered(struct d3d12_context *ctx, unsigned fill_mode)
410{
411   if ((ctx->gfx_stages[PIPE_SHADER_GEOMETRY] != NULL &&
412        !ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->is_variant) ||
413       ctx->gfx_pipeline_state.rast == NULL ||
414       ctx->gfx_pipeline_state.rast->base.cull_face == PIPE_FACE_NONE)
415      return PIPE_FACE_NONE;
416
417   return ctx->gfx_pipeline_state.rast->base.cull_face;
418}
419
/* Compute the zero-based provoking-vertex index for flat shading and set
 * *alternate to true when triangle strips make the provoking vertex
 * alternate per triangle.  Patches have no provoking vertex (returns 0).
 * The result depends on the last vertex-processing stage's output
 * primitive, not necessarily the draw mode. */
static unsigned
get_provoking_vertex(struct d3d12_selection_context *sel_ctx, bool *alternate, const struct pipe_draw_info *dinfo)
{
   if (dinfo->mode == GL_PATCHES) {
      *alternate = false;
      return 0;
   }

   struct d3d12_shader_selector *vs = sel_ctx->ctx->gfx_stages[PIPE_SHADER_VERTEX];
   struct d3d12_shader_selector *gs = sel_ctx->ctx->gfx_stages[PIPE_SHADER_GEOMETRY];
   /* A driver-generated GS variant doesn't change the primitive type, so
    * only a user GS overrides the VS as "last vertex stage" here. */
   struct d3d12_shader_selector *last_vertex_stage = gs && !gs->is_variant ? gs : vs;

   /* Make sure GL prims match Gallium prims */
   STATIC_ASSERT(GL_POINTS == PIPE_PRIM_POINTS);
   STATIC_ASSERT(GL_LINES == PIPE_PRIM_LINES);
   STATIC_ASSERT(GL_LINE_STRIP == PIPE_PRIM_LINE_STRIP);

   enum pipe_prim_type mode;
   switch (last_vertex_stage->stage) {
   case PIPE_SHADER_GEOMETRY:
      mode = (enum pipe_prim_type)last_vertex_stage->current->nir->info.gs.output_primitive;
      break;
   case PIPE_SHADER_VERTEX:
      mode = (enum pipe_prim_type)dinfo->mode;
      break;
   default:
      unreachable("Tesselation shaders are not supported");
   }

   bool flatshade_first = sel_ctx->ctx->gfx_pipeline_state.rast &&
                          sel_ctx->ctx->gfx_pipeline_state.rast->base.flatshade_first;
   /* Strips alternate winding per triangle; a GS that emits only a single
    * primitive per invocation (vertices_out <= min prim size) never
    * produces the alternation. */
   *alternate = (mode == GL_TRIANGLE_STRIP || mode == GL_TRIANGLE_STRIP_ADJACENCY) &&
                (!gs || gs->is_variant ||
                 gs->initial->info.gs.vertices_out > u_prim_vertex_count(mode)->min);
   /* First vertex when flatshade_first, otherwise the primitive's last. */
   return flatshade_first ? 0 : u_prim_vertex_count(mode)->min - 1;
}
456
/* Return true when the current fragment shader variant consumes at least
 * one flat-interpolated, non-sysval varying.  Used to decide whether
 * provoking-vertex handling (reordering/load_at_vertex) is needed. */
static bool
has_flat_varyings(struct d3d12_context *ctx)
{
   struct d3d12_shader_selector *fs = ctx->gfx_stages[PIPE_SHADER_FRAGMENT];

   if (!fs || !fs->current)
      return false;

   nir_foreach_variable_with_modes(input, fs->current->nir,
                                   nir_var_shader_in) {
      if (input->data.interpolation == INTERP_MODE_FLAT &&
          /* Disregard sysvals */
          /* The OR keeps colors/texcoords (<= TEX7) and generic varyings
           * (>= VAR0) while skipping the sysval slots that lie between
           * TEX7 and VAR0 (PSIZ, FACE, PNTC, ...). */
          (input->data.location >= VARYING_SLOT_VAR0 ||
             input->data.location <= VARYING_SLOT_TEX7))
         return true;
   }

   return false;
}
476
/* Decide whether a GS variant must reorder triangle vertices so flat
 * shading and/or transform feedback produce GL-correct results.  Side
 * effect: in the XFB-only case sel_ctx->provoking_vertex is reset to 0
 * because the provoking vertex is irrelevant without flat shading. */
static bool
needs_vertex_reordering(struct d3d12_selection_context *sel_ctx, const struct pipe_draw_info *dinfo)
{
   struct d3d12_context *ctx = sel_ctx->ctx;
   bool flat = has_flat_varyings(ctx);
   bool xfb = ctx->gfx_pipeline_state.num_so_targets > 0;

   /* Fill-mode lowering already rewrites the primitives itself. */
   if (fill_mode_lowered(ctx, dinfo) != PIPE_POLYGON_MODE_FILL)
      return false;

   /* TODO add support for line primitives */

   /* When flat shading a triangle and provoking vertex is not the first one, we use load_at_vertex.
      If not available for this adapter, or if it's a triangle strip, we need to reorder the vertices */
   if (flat && sel_ctx->provoking_vertex >= 2 && (!d3d12_screen(ctx->base.screen)->have_load_at_vertex ||
                                                  sel_ctx->alternate_tri))
      return true;

   /* When transform feedback is enabled and the output is alternating (triangle strip or triangle
      strip with adjacency), we need to reorder vertices to get the order expected by OpenGL. This
      only works when there is no flat shading involved. In that scenario, we don't care about
      the provoking vertex. */
   if (xfb && !flat && sel_ctx->alternate_tri) {
      sel_ctx->provoking_vertex = 0;
      return true;
   }

   return false;
}
506
507static nir_variable *
508create_varying_from_info(nir_shader *nir, struct d3d12_varying_info *info,
509                         unsigned slot, unsigned slot_frac, nir_variable_mode mode, bool patch)
510{
511   nir_variable *var;
512   char tmp[100];
513
514   snprintf(tmp, ARRAY_SIZE(tmp),
515            mode == nir_var_shader_in ? "in_%d" : "out_%d",
516            info->slots[slot].vars[slot_frac].driver_location);
517   var = nir_variable_create(nir, mode, info->slots[slot].types[slot_frac], tmp);
518   var->data.location = slot;
519   var->data.location_frac = slot_frac;
520   var->data.driver_location = info->slots[slot].vars[slot_frac].driver_location;
521   var->data.interpolation = info->slots[slot].vars[slot_frac].interpolation;
522   var->data.patch = info->slots[slot].patch;
523   var->data.compact = info->slots[slot].vars[slot_frac].compact;
524   if (patch)
525      var->data.location += VARYING_SLOT_PATCH0;
526
527   if (mode == nir_var_shader_out)
528      NIR_PASS_V(nir, d3d12_write_0_to_new_varying, var);
529
530   return var;
531}
532
533void
534create_varyings_from_info(nir_shader *nir, struct d3d12_varying_info *info,
535                          unsigned slot, nir_variable_mode mode, bool patch)
536{
537   unsigned mask = info->slots[slot].location_frac_mask;
538   while (mask)
539      create_varying_from_info(nir, info, slot, u_bit_scan(&mask), mode, patch);
540}
541
/* Record into *info the type and layout of every variable in `modes`
 * whose (patch-relative) location bit is set in `mask`.  `patch` selects
 * whether patch varyings or regular varyings are collected — a variable
 * is skipped unless its patch-ness matches.  GS/TCS per-vertex input
 * arrays are stripped to their element type. */
static void
fill_varyings(struct d3d12_varying_info *info, nir_shader *s,
              nir_variable_mode modes, uint64_t mask, bool patch)
{
   nir_foreach_variable_with_modes(var, s, modes) {
      unsigned slot = var->data.location;
      bool is_generic_patch = slot >= VARYING_SLOT_PATCH0;
      /* XOR: collect patch vars iff `patch` is set, non-patch otherwise. */
      if (patch ^ is_generic_patch)
         continue;
      if (is_generic_patch)
         slot -= VARYING_SLOT_PATCH0;
      uint64_t slot_bit = BITFIELD64_BIT(slot);

      if (!(mask & slot_bit))
         continue;

      /* GS and TCS inputs are arrays indexed by vertex — the varying's
       * logical type is the array element. */
      const struct glsl_type *type = var->type;
      if ((s->info.stage == MESA_SHADER_GEOMETRY ||
           s->info.stage == MESA_SHADER_TESS_CTRL) &&
          (modes & nir_var_shader_in) &&
          glsl_type_is_array(type))
         type = glsl_get_array_element(type);
      info->slots[slot].types[var->data.location_frac] = type;

      info->slots[slot].patch = var->data.patch;
      auto& var_slot = info->slots[slot].vars[var->data.location_frac];
      var_slot.driver_location = var->data.driver_location;
      var_slot.interpolation = var->data.interpolation;
      var_slot.compact = var->data.compact;
      info->mask |= slot_bit;
      info->slots[slot].location_frac_mask |= (1 << var->data.location_frac);
   }
}
575
576static void
577fill_flat_varyings(struct d3d12_gs_variant_key *key, d3d12_shader_selector *fs)
578{
579   if (!fs || !fs->current)
580      return;
581
582   nir_foreach_variable_with_modes(input, fs->current->nir,
583                                   nir_var_shader_in) {
584      if (input->data.interpolation == INTERP_MODE_FLAT)
585         key->flat_varyings |= BITFIELD64_BIT(input->data.location);
586   }
587}
588
/* Ensure the GS pipeline slot holds the correct shader for the coming
 * draw.  If no user GS is bound, build a variant key describing the
 * lowering the draw needs (fill mode, point sprites, or vertex
 * reordering) and bind the matching generated GS variant — or unbind the
 * previous variant when none is needed.  A bound user GS is left alone. */
static void
validate_geometry_shader_variant(struct d3d12_selection_context *sel_ctx)
{
   struct d3d12_context *ctx = sel_ctx->ctx;
   d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX];
   d3d12_shader_selector *fs = ctx->gfx_stages[PIPE_SHADER_FRAGMENT];
   /* Zero-initialized so the memcmp against a cached variant key below is
    * meaningful even for fields left untouched. */
   struct d3d12_gs_variant_key key = {0};
   bool variant_needed = false;

   d3d12_shader_selector *gs = ctx->gfx_stages[PIPE_SHADER_GEOMETRY];

   /* Nothing to do if there is a user geometry shader bound */
   if (gs != NULL && !gs->is_variant)
      return;

   /* Fill the geometry shader variant key */
   if (sel_ctx->fill_mode_lowered != PIPE_POLYGON_MODE_FILL) {
      key.fill_mode = sel_ctx->fill_mode_lowered;
      key.cull_mode = sel_ctx->cull_mode_lowered;
      key.has_front_face = BITSET_TEST(fs->initial->info.system_values_read, SYSTEM_VALUE_FRONT_FACE);
      /* Winding only matters when culling or gl_FrontFace is involved;
       * a y-flip (ctx->flip_y < 0) inverts the effective winding. */
      if (key.cull_mode != PIPE_FACE_NONE || key.has_front_face)
         key.front_ccw = ctx->gfx_pipeline_state.rast->base.front_ccw ^ (ctx->flip_y < 0);
      key.edge_flag_fix = needs_edge_flag_fix(ctx->initial_api_prim);
      fill_flat_varyings(&key, fs);
      if (key.flat_varyings != 0)
         key.flatshade_first = ctx->gfx_pipeline_state.rast->base.flatshade_first;
      variant_needed = true;
   } else if (sel_ctx->needs_point_sprite_lowering) {
      key.passthrough = true;
      variant_needed = true;
   } else if (sel_ctx->needs_vertex_reordering) {
      /* TODO support cases where flat shading (pv != 0) and xfb are enabled */
      key.provoking_vertex = sel_ctx->provoking_vertex;
      key.alternate_tri = sel_ctx->alternate_tri;
      variant_needed = true;
   }

   /* The variant's inputs are exactly the VS outputs. */
   if (variant_needed) {
      fill_varyings(&key.varyings, vs->initial, nir_var_shader_out,
                    vs->initial->info.outputs_written, false);
   }

   /* Check if the currently bound geometry shader variant is correct */
   if (gs && memcmp(&gs->gs_key, &key, sizeof(key)) == 0)
      return;

   /* Find/create the proper variant and bind it */
   gs = variant_needed ? d3d12_get_gs_variant(ctx, &key) : NULL;
   ctx->gfx_stages[PIPE_SHADER_GEOMETRY] = gs;
}
639
/* Ensure the TCS pipeline slot holds the correct shader: when a TES is
 * bound but the app supplied no TCS, bind a generated passthrough TCS
 * variant matching the VS outputs and current patch size; unbind the
 * variant when no TES is bound.  A user TCS is left alone. */
static void
validate_tess_ctrl_shader_variant(struct d3d12_selection_context *sel_ctx)
{
   struct d3d12_context *ctx = sel_ctx->ctx;
   d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX];
   d3d12_shader_selector *tcs = ctx->gfx_stages[PIPE_SHADER_TESS_CTRL];
   d3d12_shader_selector *tes = ctx->gfx_stages[PIPE_SHADER_TESS_EVAL];
   /* Zero-init makes the cached-key memcmp below well-defined. */
   struct d3d12_tcs_variant_key key = {0};

   /* Nothing to do if there is a user tess ctrl shader bound */
   if (tcs != NULL && !tcs->is_variant)
      return;

   /* A passthrough TCS is only needed when tessellation is actually on. */
   bool variant_needed = tes != nullptr;

   /* Fill the variant key */
   if (variant_needed) {
      fill_varyings(&key.varyings, vs->initial, nir_var_shader_out,
                    vs->initial->info.outputs_written, false);
      key.vertices_out = ctx->patch_vertices;
   }

   /* Check if the currently bound tessellation control shader variant is correct */
   if (tcs && memcmp(&tcs->tcs_key, &key, sizeof(key)) == 0)
      return;

   /* Find/create the proper variant and bind it */
   tcs = variant_needed ? d3d12_get_tcs_variant(ctx, &key) : NULL;
   ctx->gfx_stages[PIPE_SHADER_TESS_CTRL] = tcs;
}
670
671static bool
672d3d12_compare_varying_info(const d3d12_varying_info *expect, const d3d12_varying_info *have)
673{
674   if (expect->mask != have->mask)
675      return false;
676
677   if (!expect->mask)
678      return true;
679
680   /* 6 is a rough (wild) guess for a bulk memcmp cross-over point.  When there
681    * are a small number of slots present, individual memcmp is much faster. */
682   if (util_bitcount64(expect->mask) < 6) {
683      uint64_t mask = expect->mask;
684      while (mask) {
685         int slot = u_bit_scan64(&mask);
686         if (memcmp(&expect->slots[slot], &have->slots[slot], sizeof(have->slots[slot])))
687            return false;
688      }
689
690      return true;
691   }
692
693   return !memcmp(expect, have, sizeof(struct d3d12_varying_info));
694}
695
/* Deep equality check between a freshly built shader variant key and a
 * cached one.  The precomputed hash is compared first as a cheap reject;
 * the stage-specific sub-key and the common texture/image/depth state are
 * then compared field by field.  Both keys must be for the same stage. */
static bool
d3d12_compare_shader_keys(const d3d12_shader_key *expect, const d3d12_shader_key *have)
{
   assert(expect->stage == have->stage);
   assert(expect);
   assert(have);

   if (expect->hash != have->hash)
      return false;

   /* Because we only add varyings we check that a shader has at least the expected in-
    * and outputs. */

   if (!d3d12_compare_varying_info(&expect->required_varying_inputs,
                                   &have->required_varying_inputs) ||
       expect->next_varying_inputs != have->next_varying_inputs)
      return false;

   if (!d3d12_compare_varying_info(&expect->required_varying_outputs,
                                   &have->required_varying_outputs) ||
       expect->prev_varying_outputs != have->prev_varying_outputs)
      return false;

   /* Stage-specific sub-keys. */
   if (expect->stage == PIPE_SHADER_GEOMETRY) {
      /* Point-sprite fields only matter when the GS writes PSIZE. */
      if (expect->gs.writes_psize) {
         if (!have->gs.writes_psize ||
             expect->gs.point_pos_stream_out != have->gs.point_pos_stream_out ||
             expect->gs.sprite_coord_enable != have->gs.sprite_coord_enable ||
             expect->gs.sprite_origin_upper_left != have->gs.sprite_origin_upper_left ||
             expect->gs.point_size_per_vertex != have->gs.point_size_per_vertex)
            return false;
      } else if (have->gs.writes_psize) {
         return false;
      }
      if (expect->gs.primitive_id != have->gs.primitive_id ||
          expect->gs.triangle_strip != have->gs.triangle_strip)
         return false;
   } else if (expect->stage == PIPE_SHADER_FRAGMENT) {
      if (expect->fs.frag_result_color_lowering != have->fs.frag_result_color_lowering ||
          expect->fs.manual_depth_range != have->fs.manual_depth_range ||
          expect->fs.polygon_stipple != have->fs.polygon_stipple ||
          expect->fs.cast_to_uint != have->fs.cast_to_uint ||
          expect->fs.cast_to_int != have->fs.cast_to_int ||
          expect->fs.remap_front_facing != have->fs.remap_front_facing ||
          expect->fs.missing_dual_src_outputs != have->fs.missing_dual_src_outputs ||
          expect->fs.multisample_disabled != have->fs.multisample_disabled)
         return false;
   } else if (expect->stage == PIPE_SHADER_COMPUTE) {
      if (memcmp(expect->cs.workgroup_size, have->cs.workgroup_size,
                 sizeof(have->cs.workgroup_size)))
         return false;
   } else if (expect->stage == PIPE_SHADER_TESS_CTRL) {
      if (expect->hs.primitive_mode != have->hs.primitive_mode ||
          expect->hs.ccw != have->hs.ccw ||
          expect->hs.point_mode != have->hs.point_mode ||
          expect->hs.spacing != have->hs.spacing ||
          expect->hs.patch_vertices_in != have->hs.patch_vertices_in ||
          memcmp(&expect->hs.required_patch_outputs, &have->hs.required_patch_outputs,
                 sizeof(struct d3d12_varying_info)) ||
          expect->hs.next_patch_inputs != have->hs.next_patch_inputs)
         return false;
   } else if (expect->stage == PIPE_SHADER_TESS_EVAL) {
      if (expect->ds.tcs_vertices_out != have->ds.tcs_vertices_out ||
          memcmp(&expect->ds.required_patch_inputs, &have->ds.required_patch_inputs,
                 sizeof(struct d3d12_varying_info)) ||
          expect->ds.prev_patch_outputs != have ->ds.prev_patch_outputs)
         return false;
   }

   if (expect->input_clip_size != have->input_clip_size)
      return false;

   /* Common texture/sampler lowering state. */
   if (expect->tex_saturate_s != have->tex_saturate_s ||
       expect->tex_saturate_r != have->tex_saturate_r ||
       expect->tex_saturate_t != have->tex_saturate_t)
      return false;

   if (expect->samples_int_textures != have->samples_int_textures)
      return false;

   if (expect->n_texture_states != have->n_texture_states)
      return false;

   if (expect->n_images != have->n_images)
      return false;

   /* The per-texture arrays are compared only over the states in use. */
   if (memcmp(expect->tex_wrap_states, have->tex_wrap_states,
              expect->n_texture_states * sizeof(dxil_wrap_sampler_state)))
      return false;

   if (memcmp(expect->swizzle_state, have->swizzle_state,
              expect->n_texture_states * sizeof(dxil_texture_swizzle_state)))
      return false;

   if (memcmp(expect->sampler_compare_funcs, have->sampler_compare_funcs,
              expect->n_texture_states * sizeof(enum compare_func)))
      return false;

   if (memcmp(expect->image_format_conversion, have->image_format_conversion,
      expect->n_images * sizeof(struct d3d12_image_format_conversion_info)))
      return false;

   if (expect->invert_depth != have->invert_depth ||
       expect->halfz != have->halfz)
      return false;

   if (expect->stage == PIPE_SHADER_VERTEX) {
      /* The conversion table is only meaningful when emulation is on. */
      if (expect->vs.needs_format_emulation != have->vs.needs_format_emulation)
         return false;

      if (expect->vs.needs_format_emulation) {
         if (memcmp(expect->vs.format_conversion, have->vs.format_conversion,
                    PIPE_MAX_ATTRIBS * sizeof (enum pipe_format)))
            return false;
      }
   }

   if (expect->fs.provoking_vertex != have->fs.provoking_vertex)
      return false;

   return true;
}
818
819static uint32_t
820d3d12_shader_key_hash(const d3d12_shader_key *key)
821{
822   uint32_t hash;
823
824   hash = (uint32_t)key->stage;
825   hash += key->required_varying_inputs.mask;
826   hash += key->required_varying_outputs.mask;
827   hash += key->next_varying_inputs;
828   hash += key->prev_varying_outputs;
829   switch (key->stage) {
830   case PIPE_SHADER_VERTEX:
831      /* (Probably) not worth the bit extraction for needs_format_emulation and
832       * the rest of the the format_conversion data is large.  Don't bother
833       * hashing for now until this is shown to be worthwhile. */
834       break;
835   case PIPE_SHADER_GEOMETRY:
836      hash = _mesa_hash_data_with_seed(&key->gs, sizeof(key->gs), hash);
837      break;
838   case PIPE_SHADER_FRAGMENT:
839      hash = _mesa_hash_data_with_seed(&key->fs, sizeof(key->fs), hash);
840      break;
841   case PIPE_SHADER_COMPUTE:
842      hash = _mesa_hash_data_with_seed(&key->cs, sizeof(key->cs), hash);
843      break;
844   case PIPE_SHADER_TESS_CTRL:
845      hash += key->hs.next_patch_inputs;
846      break;
847   case PIPE_SHADER_TESS_EVAL:
848      hash += key->ds.tcs_vertices_out;
849      hash += key->ds.prev_patch_outputs;
850      break;
851   default:
852      /* No type specific information to hash for other stages. */
853      break;
854   }
855
856   hash += key->n_texture_states;
857   hash += key->n_images;
858   return hash;
859}
860
861static void
862d3d12_fill_shader_key(struct d3d12_selection_context *sel_ctx,
863                      d3d12_shader_key *key, d3d12_shader_selector *sel,
864                      d3d12_shader_selector *prev, d3d12_shader_selector *next)
865{
866   pipe_shader_type stage = sel->stage;
867
868   uint64_t system_generated_in_values =
869         VARYING_BIT_PNTC |
870         VARYING_BIT_PRIMITIVE_ID;
871
872   uint64_t system_out_values =
873         VARYING_BIT_CLIP_DIST0 |
874         VARYING_BIT_CLIP_DIST1;
875
876   memset(key, 0, sizeof(d3d12_shader_key));
877   key->stage = stage;
878
879   if (prev) {
880      /* We require as inputs what the previous stage has written,
881       * except certain system values */
882      if (stage == PIPE_SHADER_FRAGMENT || stage == PIPE_SHADER_GEOMETRY)
883         system_out_values |= VARYING_BIT_POS;
884      if (stage == PIPE_SHADER_FRAGMENT)
885         system_out_values |= VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT | VARYING_BIT_LAYER;
886      uint64_t mask = prev->current->nir->info.outputs_written & ~system_out_values;
887      fill_varyings(&key->required_varying_inputs, prev->current->nir,
888                    nir_var_shader_out, mask, false);
889      key->prev_varying_outputs = prev->current->nir->info.outputs_written;
890
891      if (stage == PIPE_SHADER_TESS_EVAL) {
892         uint32_t patch_mask = prev->current->nir->info.patch_outputs_written;
893         fill_varyings(&key->ds.required_patch_inputs, prev->current->nir,
894                       nir_var_shader_out, patch_mask, true);
895         key->ds.prev_patch_outputs = patch_mask;
896      }
897
898      /* Set the provoking vertex based on the previous shader output. Only set the
899       * key value if the driver actually supports changing the provoking vertex though */
900      if (stage == PIPE_SHADER_FRAGMENT && sel_ctx->ctx->gfx_pipeline_state.rast &&
901          !sel_ctx->needs_vertex_reordering &&
902          d3d12_screen(sel_ctx->ctx->base.screen)->have_load_at_vertex)
903         key->fs.provoking_vertex = sel_ctx->provoking_vertex;
904
905      /* Get the input clip distance size. The info's clip_distance_array_size corresponds
906       * to the output, and in cases of TES or GS you could have differently-sized inputs
907       * and outputs. For FS, there is no output, so it's repurposed to mean input.
908       */
909      if (stage != PIPE_SHADER_FRAGMENT)
910         key->input_clip_size = prev->current->nir->info.clip_distance_array_size;
911   }
912
913   /* We require as outputs what the next stage reads,
914    * except certain system values */
915   if (next) {
916      if (!next->is_variant) {
917         if (stage == PIPE_SHADER_VERTEX)
918            system_generated_in_values |= VARYING_BIT_POS;
919         uint64_t mask = next->current->nir->info.inputs_read & ~system_generated_in_values;
920         fill_varyings(&key->required_varying_outputs, next->current->nir,
921                       nir_var_shader_in, mask, false);
922
923         if (stage == PIPE_SHADER_TESS_CTRL) {
924            uint32_t patch_mask = next->current->nir->info.patch_outputs_read;
925            fill_varyings(&key->hs.required_patch_outputs, prev->current->nir,
926                          nir_var_shader_in, patch_mask, true);
927            key->hs.next_patch_inputs = patch_mask;
928         }
929      }
930      key->next_varying_inputs = next->current->nir->info.inputs_read;
931
932   }
933
934   if (stage == PIPE_SHADER_GEOMETRY ||
935       ((stage == PIPE_SHADER_VERTEX || stage == PIPE_SHADER_TESS_EVAL) &&
936          (!next || next->stage == PIPE_SHADER_FRAGMENT))) {
937      key->last_vertex_processing_stage = 1;
938      key->invert_depth = sel_ctx->ctx->reverse_depth_range;
939      key->halfz = sel_ctx->ctx->gfx_pipeline_state.rast ?
940         sel_ctx->ctx->gfx_pipeline_state.rast->base.clip_halfz : false;
941      if (sel_ctx->ctx->pstipple.enabled &&
942         sel_ctx->ctx->gfx_pipeline_state.rast->base.poly_stipple_enable)
943         key->next_varying_inputs |= VARYING_BIT_POS;
944   }
945
946   if (stage == PIPE_SHADER_GEOMETRY && sel_ctx->ctx->gfx_pipeline_state.rast) {
947      struct pipe_rasterizer_state *rast = &sel_ctx->ctx->gfx_pipeline_state.rast->base;
948      if (sel_ctx->needs_point_sprite_lowering) {
949         key->gs.writes_psize = 1;
950         key->gs.point_size_per_vertex = rast->point_size_per_vertex;
951         key->gs.sprite_coord_enable = rast->sprite_coord_enable;
952         key->gs.sprite_origin_upper_left = (rast->sprite_coord_mode != PIPE_SPRITE_COORD_LOWER_LEFT);
953         if (sel_ctx->ctx->flip_y < 0)
954            key->gs.sprite_origin_upper_left = !key->gs.sprite_origin_upper_left;
955         key->gs.aa_point = rast->point_smooth;
956         key->gs.stream_output_factor = 6;
957      } else if (sel_ctx->fill_mode_lowered == PIPE_POLYGON_MODE_LINE) {
958         key->gs.stream_output_factor = 2;
959      } else if (sel_ctx->needs_vertex_reordering && !sel->is_variant) {
960         key->gs.triangle_strip = 1;
961      }
962
963      if (sel->is_variant && next && next->initial->info.inputs_read & VARYING_BIT_PRIMITIVE_ID)
964         key->gs.primitive_id = 1;
965   } else if (stage == PIPE_SHADER_FRAGMENT) {
966      key->fs.missing_dual_src_outputs = sel_ctx->missing_dual_src_outputs;
967      key->fs.frag_result_color_lowering = sel_ctx->frag_result_color_lowering;
968      key->fs.manual_depth_range = sel_ctx->manual_depth_range;
969      key->fs.polygon_stipple = sel_ctx->ctx->pstipple.enabled &&
970         sel_ctx->ctx->gfx_pipeline_state.rast->base.poly_stipple_enable;
971      key->fs.multisample_disabled = sel_ctx->ctx->gfx_pipeline_state.rast &&
972         !sel_ctx->ctx->gfx_pipeline_state.rast->desc.MultisampleEnable;
973      if (sel_ctx->ctx->gfx_pipeline_state.blend &&
974          sel_ctx->ctx->gfx_pipeline_state.blend->desc.RenderTarget[0].LogicOpEnable &&
975          !sel_ctx->ctx->gfx_pipeline_state.has_float_rtv) {
976         key->fs.cast_to_uint = util_format_is_unorm(sel_ctx->ctx->fb.cbufs[0]->format);
977         key->fs.cast_to_int = !key->fs.cast_to_uint;
978      }
979   } else if (stage == PIPE_SHADER_TESS_CTRL) {
980      if (next && next->current->nir->info.stage == MESA_SHADER_TESS_EVAL) {
981         key->hs.primitive_mode = next->current->nir->info.tess._primitive_mode;
982         key->hs.ccw = next->current->nir->info.tess.ccw;
983         key->hs.point_mode = next->current->nir->info.tess.point_mode;
984         key->hs.spacing = next->current->nir->info.tess.spacing;
985      } else {
986         key->hs.primitive_mode = TESS_PRIMITIVE_QUADS;
987         key->hs.ccw = true;
988         key->hs.point_mode = false;
989         key->hs.spacing = TESS_SPACING_EQUAL;
990      }
991      key->hs.patch_vertices_in = MAX2(sel_ctx->ctx->patch_vertices, 1);
992   } else if (stage == PIPE_SHADER_TESS_EVAL) {
993      if (prev && prev->current->nir->info.stage == MESA_SHADER_TESS_CTRL)
994         key->ds.tcs_vertices_out = prev->current->nir->info.tess.tcs_vertices_out;
995      else
996         key->ds.tcs_vertices_out = 32;
997   }
998
999   if (sel->samples_int_textures) {
1000      key->samples_int_textures = sel->samples_int_textures;
1001      key->n_texture_states = sel_ctx->ctx->num_sampler_views[stage];
1002      /* Copy only states with integer textures */
1003      for(int i = 0; i < key->n_texture_states; ++i) {
1004         auto& wrap_state = sel_ctx->ctx->tex_wrap_states[stage][i];
1005         if (wrap_state.is_int_sampler) {
1006            memcpy(&key->tex_wrap_states[i], &wrap_state, sizeof(wrap_state));
1007            key->swizzle_state[i] = sel_ctx->ctx->tex_swizzle_state[stage][i];
1008         }
1009      }
1010   }
1011
1012   for (unsigned i = 0; i < sel_ctx->ctx->num_samplers[stage]; ++i) {
1013      if (!sel_ctx->ctx->samplers[stage][i] ||
1014          sel_ctx->ctx->samplers[stage][i]->filter == PIPE_TEX_FILTER_NEAREST)
1015         continue;
1016
1017      if (sel_ctx->ctx->samplers[stage][i]->wrap_r == PIPE_TEX_WRAP_CLAMP)
1018         key->tex_saturate_r |= 1 << i;
1019      if (sel_ctx->ctx->samplers[stage][i]->wrap_s == PIPE_TEX_WRAP_CLAMP)
1020         key->tex_saturate_s |= 1 << i;
1021      if (sel_ctx->ctx->samplers[stage][i]->wrap_t == PIPE_TEX_WRAP_CLAMP)
1022         key->tex_saturate_t |= 1 << i;
1023   }
1024
1025   if (sel->compare_with_lod_bias_grad) {
1026      key->n_texture_states = sel_ctx->ctx->num_sampler_views[stage];
1027      memcpy(key->sampler_compare_funcs, sel_ctx->ctx->tex_compare_func[stage],
1028             key->n_texture_states * sizeof(enum compare_func));
1029      memcpy(key->swizzle_state, sel_ctx->ctx->tex_swizzle_state[stage],
1030             key->n_texture_states * sizeof(dxil_texture_swizzle_state));
1031   }
1032
1033   if (stage == PIPE_SHADER_VERTEX && sel_ctx->ctx->gfx_pipeline_state.ves) {
1034      key->vs.needs_format_emulation = sel_ctx->ctx->gfx_pipeline_state.ves->needs_format_emulation;
1035      if (key->vs.needs_format_emulation) {
1036         memcpy(key->vs.format_conversion, sel_ctx->ctx->gfx_pipeline_state.ves->format_conversion,
1037                sel_ctx->ctx->gfx_pipeline_state.ves->num_elements * sizeof(enum pipe_format));
1038      }
1039   }
1040
1041   if (stage == PIPE_SHADER_FRAGMENT &&
1042       sel_ctx->ctx->gfx_stages[PIPE_SHADER_GEOMETRY] &&
1043       sel_ctx->ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->is_variant &&
1044       sel_ctx->ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->gs_key.has_front_face) {
1045      key->fs.remap_front_facing = 1;
1046   }
1047
1048   if (stage == PIPE_SHADER_COMPUTE && sel_ctx->variable_workgroup_size) {
1049      memcpy(key->cs.workgroup_size, sel_ctx->variable_workgroup_size, sizeof(key->cs.workgroup_size));
1050   }
1051
1052   key->n_images = sel_ctx->ctx->num_image_views[stage];
1053   for (int i = 0; i < key->n_images; ++i) {
1054      key->image_format_conversion[i].emulated_format = sel_ctx->ctx->image_view_emulation_formats[stage][i];
1055      if (key->image_format_conversion[i].emulated_format != PIPE_FORMAT_NONE)
1056         key->image_format_conversion[i].view_format = sel_ctx->ctx->image_views[stage][i].format;
1057   }
1058
1059   key->hash = d3d12_shader_key_hash(key);
1060}
1061
1062static void
1063select_shader_variant(struct d3d12_selection_context *sel_ctx, d3d12_shader_selector *sel,
1064                     d3d12_shader_selector *prev, d3d12_shader_selector *next)
1065{
1066   struct d3d12_context *ctx = sel_ctx->ctx;
1067   d3d12_shader_key key;
1068   nir_shader *new_nir_variant;
1069   unsigned pstipple_binding = UINT32_MAX;
1070
1071   d3d12_fill_shader_key(sel_ctx, &key, sel, prev, next);
1072
1073   /* Check for an existing variant */
1074   for (d3d12_shader *variant = sel->first; variant;
1075        variant = variant->next_variant) {
1076
1077      if (d3d12_compare_shader_keys(&key, &variant->key)) {
1078         sel->current = variant;
1079         return;
1080      }
1081   }
1082
1083   /* Clone the NIR shader */
1084   new_nir_variant = nir_shader_clone(sel, sel->initial);
1085
1086   /* Apply any needed lowering passes */
1087   if (key.gs.writes_psize) {
1088      NIR_PASS_V(new_nir_variant, d3d12_lower_point_sprite,
1089                 !key.gs.sprite_origin_upper_left,
1090                 key.gs.point_size_per_vertex,
1091                 key.gs.sprite_coord_enable,
1092                 key.next_varying_inputs);
1093
1094      nir_function_impl *impl = nir_shader_get_entrypoint(new_nir_variant);
1095      nir_shader_gather_info(new_nir_variant, impl);
1096   }
1097
1098   if (key.gs.primitive_id) {
1099      NIR_PASS_V(new_nir_variant, d3d12_lower_primitive_id);
1100
1101      nir_function_impl *impl = nir_shader_get_entrypoint(new_nir_variant);
1102      nir_shader_gather_info(new_nir_variant, impl);
1103   }
1104
1105   if (key.gs.triangle_strip)
1106      NIR_PASS_V(new_nir_variant, d3d12_lower_triangle_strip);
1107
1108   if (key.fs.polygon_stipple) {
1109      NIR_PASS_V(new_nir_variant, nir_lower_pstipple_fs,
1110                 &pstipple_binding, 0, false);
1111
1112      nir_function_impl *impl = nir_shader_get_entrypoint(new_nir_variant);
1113      nir_shader_gather_info(new_nir_variant, impl);
1114   }
1115
1116   if (key.fs.remap_front_facing) {
1117      d3d12_forward_front_face(new_nir_variant);
1118
1119      nir_function_impl *impl = nir_shader_get_entrypoint(new_nir_variant);
1120      nir_shader_gather_info(new_nir_variant, impl);
1121   }
1122
1123   if (key.fs.missing_dual_src_outputs) {
1124      NIR_PASS_V(new_nir_variant, d3d12_add_missing_dual_src_target,
1125                 key.fs.missing_dual_src_outputs);
1126   } else if (key.fs.frag_result_color_lowering) {
1127      NIR_PASS_V(new_nir_variant, nir_lower_fragcolor,
1128                 key.fs.frag_result_color_lowering);
1129   }
1130
1131   if (key.fs.manual_depth_range)
1132      NIR_PASS_V(new_nir_variant, d3d12_lower_depth_range);
1133
1134   if (sel->compare_with_lod_bias_grad) {
1135      STATIC_ASSERT(sizeof(dxil_texture_swizzle_state) ==
1136                    sizeof(nir_lower_tex_shadow_swizzle));
1137
1138      NIR_PASS_V(new_nir_variant, nir_lower_tex_shadow, key.n_texture_states,
1139                 key.sampler_compare_funcs, (nir_lower_tex_shadow_swizzle *)key.swizzle_state);
1140   }
1141
1142   if (key.fs.cast_to_uint)
1143      NIR_PASS_V(new_nir_variant, d3d12_lower_uint_cast, false);
1144   if (key.fs.cast_to_int)
1145      NIR_PASS_V(new_nir_variant, d3d12_lower_uint_cast, true);
1146
1147   if (key.n_images)
1148      NIR_PASS_V(new_nir_variant, d3d12_lower_image_casts, key.image_format_conversion);
1149
1150   if (sel->workgroup_size_variable) {
1151      new_nir_variant->info.workgroup_size[0] = key.cs.workgroup_size[0];
1152      new_nir_variant->info.workgroup_size[1] = key.cs.workgroup_size[1];
1153      new_nir_variant->info.workgroup_size[2] = key.cs.workgroup_size[2];
1154   }
1155
1156   if (new_nir_variant->info.stage == MESA_SHADER_TESS_CTRL) {
1157      new_nir_variant->info.tess._primitive_mode = (tess_primitive_mode)key.hs.primitive_mode;
1158      new_nir_variant->info.tess.ccw = key.hs.ccw;
1159      new_nir_variant->info.tess.point_mode = key.hs.point_mode;
1160      new_nir_variant->info.tess.spacing = key.hs.spacing;
1161
1162      NIR_PASS_V(new_nir_variant, dxil_nir_set_tcs_patches_in, key.hs.patch_vertices_in);
1163   } else if (new_nir_variant->info.stage == MESA_SHADER_TESS_EVAL) {
1164      new_nir_variant->info.tess.tcs_vertices_out = key.ds.tcs_vertices_out;
1165   }
1166
1167   {
1168      struct nir_lower_tex_options tex_options = { };
1169      tex_options.lower_txp = ~0u; /* No equivalent for textureProj */
1170      tex_options.lower_rect = true;
1171      tex_options.lower_rect_offset = true;
1172      tex_options.saturate_s = key.tex_saturate_s;
1173      tex_options.saturate_r = key.tex_saturate_r;
1174      tex_options.saturate_t = key.tex_saturate_t;
1175      tex_options.lower_invalid_implicit_lod = true;
1176      tex_options.lower_tg4_offsets = true;
1177
1178      NIR_PASS_V(new_nir_variant, nir_lower_tex, &tex_options);
1179   }
1180
1181   /* Add the needed in and outputs, and re-sort */
1182   if (prev) {
1183      uint64_t mask = key.required_varying_inputs.mask & ~new_nir_variant->info.inputs_read;
1184      new_nir_variant->info.inputs_read |= mask;
1185      while (mask) {
1186         int slot = u_bit_scan64(&mask);
1187         create_varyings_from_info(new_nir_variant, &key.required_varying_inputs, slot, nir_var_shader_in, false);
1188      }
1189
1190      if (sel->stage == PIPE_SHADER_TESS_EVAL) {
1191         uint32_t patch_mask = (uint32_t)key.ds.required_patch_inputs.mask & ~new_nir_variant->info.patch_inputs_read;
1192         new_nir_variant->info.patch_inputs_read |= patch_mask;
1193         while (patch_mask) {
1194            int slot = u_bit_scan(&patch_mask);
1195            create_varyings_from_info(new_nir_variant, &key.ds.required_patch_inputs, slot, nir_var_shader_in, true);
1196         }
1197      }
1198      dxil_reassign_driver_locations(new_nir_variant, nir_var_shader_in,
1199                                      key.prev_varying_outputs);
1200   }
1201
1202
1203   if (next) {
1204      uint64_t mask = key.required_varying_outputs.mask & ~new_nir_variant->info.outputs_written;
1205      new_nir_variant->info.outputs_written |= mask;
1206      while (mask) {
1207         int slot = u_bit_scan64(&mask);
1208         create_varyings_from_info(new_nir_variant, &key.required_varying_outputs, slot, nir_var_shader_out, false);
1209      }
1210
1211      if (sel->stage == PIPE_SHADER_TESS_CTRL) {
1212         uint32_t patch_mask = (uint32_t)key.hs.required_patch_outputs.mask & ~new_nir_variant->info.patch_outputs_written;
1213         new_nir_variant->info.patch_outputs_written |= patch_mask;
1214         while (patch_mask) {
1215            int slot = u_bit_scan(&patch_mask);
1216            create_varyings_from_info(new_nir_variant, &key.ds.required_patch_inputs, slot, nir_var_shader_out, true);
1217         }
1218      }
1219      dxil_reassign_driver_locations(new_nir_variant, nir_var_shader_out,
1220                                     key.next_varying_inputs);
1221   }
1222
1223   d3d12_shader *new_variant = compile_nir(ctx, sel, &key, new_nir_variant);
1224   assert(new_variant);
1225
1226   /* keep track of polygon stipple texture binding */
1227   new_variant->pstipple_binding = pstipple_binding;
1228
1229   /* prepend the new shader in the selector chain and pick it */
1230   new_variant->next_variant = sel->first;
1231   sel->current = sel->first = new_variant;
1232}
1233
1234static d3d12_shader_selector *
1235get_prev_shader(struct d3d12_context *ctx, pipe_shader_type current)
1236{
1237   switch (current) {
1238   case PIPE_SHADER_VERTEX:
1239      return NULL;
1240   case PIPE_SHADER_FRAGMENT:
1241      if (ctx->gfx_stages[PIPE_SHADER_GEOMETRY])
1242         return ctx->gfx_stages[PIPE_SHADER_GEOMETRY];
1243      FALLTHROUGH;
1244   case PIPE_SHADER_GEOMETRY:
1245      if (ctx->gfx_stages[PIPE_SHADER_TESS_EVAL])
1246         return ctx->gfx_stages[PIPE_SHADER_TESS_EVAL];
1247      FALLTHROUGH;
1248   case PIPE_SHADER_TESS_EVAL:
1249      if (ctx->gfx_stages[PIPE_SHADER_TESS_CTRL])
1250         return ctx->gfx_stages[PIPE_SHADER_TESS_CTRL];
1251      FALLTHROUGH;
1252   case PIPE_SHADER_TESS_CTRL:
1253      return ctx->gfx_stages[PIPE_SHADER_VERTEX];
1254   default:
1255      unreachable("shader type not supported");
1256   }
1257}
1258
1259static d3d12_shader_selector *
1260get_next_shader(struct d3d12_context *ctx, pipe_shader_type current)
1261{
1262   switch (current) {
1263   case PIPE_SHADER_VERTEX:
1264      if (ctx->gfx_stages[PIPE_SHADER_TESS_CTRL])
1265         return ctx->gfx_stages[PIPE_SHADER_TESS_CTRL];
1266      FALLTHROUGH;
1267   case PIPE_SHADER_TESS_CTRL:
1268      if (ctx->gfx_stages[PIPE_SHADER_TESS_EVAL])
1269         return ctx->gfx_stages[PIPE_SHADER_TESS_EVAL];
1270      FALLTHROUGH;
1271   case PIPE_SHADER_TESS_EVAL:
1272      if (ctx->gfx_stages[PIPE_SHADER_GEOMETRY])
1273         return ctx->gfx_stages[PIPE_SHADER_GEOMETRY];
1274      FALLTHROUGH;
1275   case PIPE_SHADER_GEOMETRY:
1276      return ctx->gfx_stages[PIPE_SHADER_FRAGMENT];
1277   case PIPE_SHADER_FRAGMENT:
1278      return NULL;
1279   default:
1280      unreachable("shader type not supported");
1281   }
1282}
1283
/* Bitmask of sampling patterns found by scan_texture_use() that require
 * state-dependent lowering of the shader. */
enum tex_scan_flags {
   TEX_SAMPLE_INTEGER_TEXTURE = 1 << 0, /* samples an int/uint-typed texture */
   TEX_CMP_WITH_LOD_BIAS_GRAD = 1 << 1, /* shadow compare with explicit LOD/bias/gradient */
   TEX_SCAN_ALL_FLAGS         = (1 << 2) - 1 /* all of the above set */
};
1289
1290static unsigned
1291scan_texture_use(nir_shader *nir)
1292{
1293   unsigned result = 0;
1294   nir_foreach_function(func, nir) {
1295      nir_foreach_block(block, func->impl) {
1296         nir_foreach_instr(instr, block) {
1297            if (instr->type == nir_instr_type_tex) {
1298               auto tex = nir_instr_as_tex(instr);
1299               switch (tex->op) {
1300               case nir_texop_txb:
1301               case nir_texop_txl:
1302               case nir_texop_txd:
1303                  if (tex->is_shadow)
1304                     result |= TEX_CMP_WITH_LOD_BIAS_GRAD;
1305                  FALLTHROUGH;
1306               case nir_texop_tex:
1307                  if (tex->dest_type & (nir_type_int | nir_type_uint))
1308                     result |= TEX_SAMPLE_INTEGER_TEXTURE;
1309               default:
1310                  ;
1311               }
1312            }
1313            if (TEX_SCAN_ALL_FLAGS == result)
1314               return result;
1315         }
1316      }
1317   }
1318   return result;
1319}
1320
1321static uint64_t
1322update_so_info(struct pipe_stream_output_info *so_info,
1323               uint64_t outputs_written)
1324{
1325   uint64_t so_outputs = 0;
1326   uint8_t reverse_map[64] = {0};
1327   unsigned slot = 0;
1328
1329   while (outputs_written)
1330      reverse_map[slot++] = u_bit_scan64(&outputs_written);
1331
1332   for (unsigned i = 0; i < so_info->num_outputs; i++) {
1333      struct pipe_stream_output *output = &so_info->output[i];
1334
1335      /* Map Gallium's condensed "slots" back to real VARYING_SLOT_* enums */
1336      output->register_index = reverse_map[output->register_index];
1337
1338      so_outputs |= 1ull << output->register_index;
1339   }
1340
1341   return so_outputs;
1342}
1343
1344static struct d3d12_shader_selector *
1345d3d12_create_shader_impl(struct d3d12_context *ctx,
1346                         struct d3d12_shader_selector *sel,
1347                         struct nir_shader *nir,
1348                         struct d3d12_shader_selector *prev,
1349                         struct d3d12_shader_selector *next)
1350{
1351   unsigned tex_scan_result = scan_texture_use(nir);
1352   sel->samples_int_textures = (tex_scan_result & TEX_SAMPLE_INTEGER_TEXTURE) != 0;
1353   sel->compare_with_lod_bias_grad = (tex_scan_result & TEX_CMP_WITH_LOD_BIAS_GRAD) != 0;
1354   sel->workgroup_size_variable = nir->info.workgroup_size_variable;
1355
1356   /* Integer cube maps are not supported in DirectX because sampling is not supported
1357    * on integer textures and TextureLoad is not supported for cube maps, so we have to
1358    * lower integer cube maps to be handled like 2D textures arrays*/
1359   NIR_PASS_V(nir, dxil_nir_lower_int_cubemaps, true);
1360
1361   /* Keep this initial shader as the blue print for possible variants */
1362   sel->initial = nir;
1363
1364   /*
1365    * We must compile some shader here, because if the previous or a next shaders exists later
1366    * when the shaders are bound, then the key evaluation in the shader selector will access
1367    * the current variant of these  prev and next shader, and we can only assign
1368    * a current variant when it has been successfully compiled.
1369    *
1370    * For shaders that require lowering because certain instructions are not available
1371    * and their emulation is state depended (like sampling an integer texture that must be
1372    * emulated and needs handling of boundary conditions, or shadow compare sampling with LOD),
1373    * we must go through the shader selector here to create a compilable variant.
1374    * For shaders that are not depended on the state this is just compiling the original
1375    * shader.
1376    *
1377    * TODO: get rid of having to compiling the shader here if it can be forseen that it will
1378    * be thrown away (i.e. it depends on states that are likely to change before the shader is
1379    * used for the first time)
1380    */
1381   struct d3d12_selection_context sel_ctx = {0};
1382   sel_ctx.ctx = ctx;
1383   select_shader_variant(&sel_ctx, sel, prev, next);
1384
1385   if (!sel->current) {
1386      ralloc_free(sel);
1387      return NULL;
1388   }
1389
1390   return sel;
1391}
1392
1393struct d3d12_shader_selector *
1394d3d12_create_shader(struct d3d12_context *ctx,
1395                    pipe_shader_type stage,
1396                    const struct pipe_shader_state *shader)
1397{
1398   struct d3d12_shader_selector *sel = rzalloc(nullptr, d3d12_shader_selector);
1399   sel->stage = stage;
1400
1401   struct nir_shader *nir = NULL;
1402
1403   if (shader->type == PIPE_SHADER_IR_NIR) {
1404      nir = (nir_shader *)shader->ir.nir;
1405   } else {
1406      assert(shader->type == PIPE_SHADER_IR_TGSI);
1407      nir = tgsi_to_nir(shader->tokens, ctx->base.screen, false);
1408   }
1409
1410   nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
1411   memcpy(&sel->so_info, &shader->stream_output, sizeof(sel->so_info));
1412   update_so_info(&sel->so_info, nir->info.outputs_written);
1413
1414   assert(nir != NULL);
1415   d3d12_shader_selector *prev = get_prev_shader(ctx, sel->stage);
1416   d3d12_shader_selector *next = get_next_shader(ctx, sel->stage);
1417
1418   NIR_PASS_V(nir, dxil_nir_split_clip_cull_distance);
1419   NIR_PASS_V(nir, d3d12_split_multistream_varyings);
1420
1421   if (nir->info.stage != MESA_SHADER_VERTEX)
1422      nir->info.inputs_read =
1423            dxil_reassign_driver_locations(nir, nir_var_shader_in,
1424                                            prev ? prev->current->nir->info.outputs_written : 0);
1425   else
1426      nir->info.inputs_read = dxil_sort_by_driver_location(nir, nir_var_shader_in);
1427
1428   if (nir->info.stage != MESA_SHADER_FRAGMENT) {
1429      nir->info.outputs_written =
1430            dxil_reassign_driver_locations(nir, nir_var_shader_out,
1431                                            next ? next->current->nir->info.inputs_read : 0);
1432   } else {
1433      NIR_PASS_V(nir, nir_lower_fragcoord_wtrans);
1434      NIR_PASS_V(nir, d3d12_lower_sample_pos);
1435      dxil_sort_ps_outputs(nir);
1436   }
1437
1438   return d3d12_create_shader_impl(ctx, sel, nir, prev, next);
1439}
1440
1441struct d3d12_shader_selector *
1442d3d12_create_compute_shader(struct d3d12_context *ctx,
1443                            const struct pipe_compute_state *shader)
1444{
1445   struct d3d12_shader_selector *sel = rzalloc(nullptr, d3d12_shader_selector);
1446   sel->stage = PIPE_SHADER_COMPUTE;
1447
1448   struct nir_shader *nir = NULL;
1449
1450   if (shader->ir_type == PIPE_SHADER_IR_NIR) {
1451      nir = (nir_shader *)shader->prog;
1452   } else {
1453      assert(shader->ir_type == PIPE_SHADER_IR_TGSI);
1454      nir = tgsi_to_nir(shader->prog, ctx->base.screen, false);
1455   }
1456
1457   nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
1458
1459   NIR_PASS_V(nir, d3d12_lower_compute_state_vars);
1460
1461   return d3d12_create_shader_impl(ctx, sel, nir, nullptr, nullptr);
1462}
1463
1464void
1465d3d12_select_shader_variants(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo)
1466{
1467   static unsigned order[] = {
1468      PIPE_SHADER_VERTEX,
1469      PIPE_SHADER_TESS_CTRL,
1470      PIPE_SHADER_TESS_EVAL,
1471      PIPE_SHADER_GEOMETRY,
1472      PIPE_SHADER_FRAGMENT
1473   };
1474   struct d3d12_selection_context sel_ctx;
1475
1476   sel_ctx.ctx = ctx;
1477   sel_ctx.needs_point_sprite_lowering = needs_point_sprite_lowering(ctx, dinfo);
1478   sel_ctx.fill_mode_lowered = fill_mode_lowered(ctx, dinfo);
1479   sel_ctx.cull_mode_lowered = cull_mode_lowered(ctx, sel_ctx.fill_mode_lowered);
1480   sel_ctx.provoking_vertex = get_provoking_vertex(&sel_ctx, &sel_ctx.alternate_tri, dinfo);
1481   sel_ctx.needs_vertex_reordering = needs_vertex_reordering(&sel_ctx, dinfo);
1482   sel_ctx.missing_dual_src_outputs = missing_dual_src_outputs(ctx);
1483   sel_ctx.frag_result_color_lowering = frag_result_color_lowering(ctx);
1484   sel_ctx.manual_depth_range = manual_depth_range(ctx);
1485
1486   validate_geometry_shader_variant(&sel_ctx);
1487   validate_tess_ctrl_shader_variant(&sel_ctx);
1488
1489   for (unsigned i = 0; i < ARRAY_SIZE(order); ++i) {
1490      auto sel = ctx->gfx_stages[order[i]];
1491      if (!sel)
1492         continue;
1493
1494      d3d12_shader_selector *prev = get_prev_shader(ctx, sel->stage);
1495      d3d12_shader_selector *next = get_next_shader(ctx, sel->stage);
1496
1497      select_shader_variant(&sel_ctx, sel, prev, next);
1498   }
1499}
1500
1501static const unsigned *
1502workgroup_size_variable(struct d3d12_context *ctx,
1503                        const struct pipe_grid_info *info)
1504{
1505   if (ctx->compute_state->workgroup_size_variable)
1506      return info->block;
1507   return nullptr;
1508}
1509
1510void
1511d3d12_select_compute_shader_variants(struct d3d12_context *ctx, const struct pipe_grid_info *info)
1512{
1513   struct d3d12_selection_context sel_ctx = {};
1514
1515   sel_ctx.ctx = ctx;
1516   sel_ctx.variable_workgroup_size = workgroup_size_variable(ctx, info);
1517
1518   select_shader_variant(&sel_ctx, ctx->compute_state, nullptr, nullptr);
1519}
1520
1521void
1522d3d12_shader_free(struct d3d12_shader_selector *sel)
1523{
1524   auto shader = sel->first;
1525   while (shader) {
1526      free(shader->bytecode);
1527      shader = shader->next_variant;
1528   }
1529   ralloc_free(sel->initial);
1530   ralloc_free(sel);
1531}
1532