/*
 * Copyright 2018 Collabora Ltd.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * on the rights to use, copy, modify, merge, publish, distribute, sub
 * license, and/or sell copies of the Software, and to permit persons to whom
 * the Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
 * USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include "zink_program.h"

#include "zink_compiler.h"
#include "zink_context.h"
#include "zink_descriptors.h"
#include "zink_helpers.h"
#include "zink_render_pass.h"
#include "zink_resource.h"
#include "zink_screen.h"
#include "zink_state.h"
#include "zink_inlines.h"

#include "util/hash_table.h"
#include "util/set.h"
#include "util/u_debug.h"
#include "util/u_memory.h"
#include "util/u_prim.h"
#include "tgsi/tgsi_from_mesa.h"

/* for pipeline cache */
#define XXH_INLINE_ALL
#include "util/xxhash.h"

struct gfx_pipeline_cache_entry {
   struct zink_gfx_pipeline_state state;
   VkPipeline pipeline;
};

struct compute_pipeline_cache_entry {
   struct zink_compute_pipeline_state state;
   VkPipeline pipeline;
};

void
debug_describe_zink_gfx_program(char *buf, const struct zink_gfx_program *ptr)
{
   sprintf(buf, "zink_gfx_program");
}

void
debug_describe_zink_compute_program(char *buf, const struct zink_compute_program *ptr)
{
   sprintf(buf, "zink_compute_program");
}

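/* compare a cached shader module variant against the current shader key;
 * the trailing key data of a zink_shader_module is laid out as
 * [shader key][nonseamless cube mask][inlined uniform values]
 */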
static bool
shader_key_matches(const struct zink_shader_module *zm, bool ignore_size,
                   const struct zink_shader_key *key, unsigned num_uniforms)
{
   bool key_size_differs = ignore_size ? false : zm->key_size != key->size;
   if (key_size_differs || zm->num_uniforms != num_uniforms || zm->has_nonseamless != !!key->base.nonseamless_cube_mask)
      return false;
   const uint32_t nonseamless_size = zm->has_nonseamless ? sizeof(uint32_t) : 0;
   return !memcmp(zm->key, key, zm->key_size) &&
          (!nonseamless_size || !memcmp(zm->key + zm->key_size, &key->base.nonseamless_cube_mask, nonseamless_size)) &&
          (!num_uniforms || !memcmp(zm->key + zm->key_size + nonseamless_size,
                                    key->base.inlined_uniform_values, zm->num_uniforms * sizeof(uint32_t)));
}

static uint32_t
shader_module_hash(const struct zink_shader_module *zm)
{
   const uint32_t nonseamless_size = zm->has_nonseamless ? sizeof(uint32_t) : 0;
   unsigned key_size = zm->key_size + nonseamless_size + zm->num_uniforms * sizeof(uint32_t);
   return _mesa_hash_data(zm->key, key_size);
}

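/* find (or compile) the shader module variant matching the current shader key;
 * variants are cached in per-stage lists bucketed by
 * [uses nonseamless][uses inlined uniforms], and a cache hit moves the module
 * to the head of its list
 */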
static struct zink_shader_module *
get_shader_module_for_stage(struct zink_context *ctx, struct zink_screen *screen,
                            struct zink_shader *zs, struct zink_gfx_program *prog,
                            struct zink_gfx_pipeline_state *state)
{
   gl_shader_stage stage = zs->nir->info.stage;
   enum pipe_shader_type pstage = pipe_shader_type_from_mesa(stage);
   VkShaderModule mod;
   struct zink_shader_module *zm = NULL;
   unsigned inline_size = 0, nonseamless_size = 0;
   struct zink_shader_key *key = &state->shader_keys.key[pstage];
   bool ignore_key_size = false;
   if (pstage == PIPE_SHADER_TESS_CTRL && !zs->is_generated) {
      /* non-generated tcs won't use the shader key */
      ignore_key_size = true;
   }
   if (ctx && zs->nir->info.num_inlinable_uniforms &&
       ctx->inlinable_uniforms_valid_mask & BITFIELD64_BIT(pstage)) {
      if (zs->can_inline && (screen->is_cpu || prog->inlined_variant_count[pstage] < ZINK_MAX_INLINED_VARIANTS))
         inline_size = zs->nir->info.num_inlinable_uniforms;
      else
         key->inline_uniforms = false;
   }
   if (key->base.nonseamless_cube_mask)
      nonseamless_size = sizeof(uint32_t);

   struct zink_shader_module *iter, *next;
   LIST_FOR_EACH_ENTRY_SAFE(iter, next, &prog->shader_cache[pstage][!!nonseamless_size][!!inline_size], list) {
      if (!shader_key_matches(iter, ignore_key_size, key, inline_size))
         continue;
      list_delinit(&iter->list);
      zm = iter;
      break;
   }

   if (!zm) {
      zm = malloc(sizeof(struct zink_shader_module) + key->size + nonseamless_size + inline_size * sizeof(uint32_t));
      if (!zm) {
         return NULL;
      }
      unsigned patch_vertices = state->shader_keys.key[PIPE_SHADER_TESS_CTRL].key.tcs.patch_vertices;
      if (pstage == PIPE_SHADER_TESS_CTRL && zs->is_generated && zs->spirv) {
         assert(ctx); //TODO async
         mod = zink_shader_tcs_compile(screen, zs, patch_vertices);
      } else {
         mod = zink_shader_compile(screen, zs, prog->nir[stage], key);
      }
      if (!mod) {
         FREE(zm);
         return NULL;
      }
      zm->shader = mod;
      list_inithead(&zm->list);
      zm->num_uniforms = inline_size;
      if (!ignore_key_size) {
         zm->key_size = key->size;
         memcpy(zm->key, key, key->size);
      } else {
         zm->key_size = 0;
         memset(zm->key, 0, key->size);
      }
      if (nonseamless_size) {
         /* nonseamless mask gets added to base key if it exists */
         memcpy(zm->key + key->size, &key->base.nonseamless_cube_mask, nonseamless_size);
      }
      zm->has_nonseamless = !!nonseamless_size;
      if (inline_size)
         memcpy(zm->key + key->size + nonseamless_size, key->base.inlined_uniform_values, inline_size * sizeof(uint32_t));
      if (pstage == PIPE_SHADER_TESS_CTRL && zs->is_generated)
         zm->hash = patch_vertices;
      else
         zm->hash = shader_module_hash(zm);
      zm->default_variant = !inline_size && list_is_empty(&prog->shader_cache[pstage][0][0]);
      if (inline_size)
         prog->inlined_variant_count[pstage]++;
   }
   list_add(&zm->list, &prog->shader_cache[pstage][!!nonseamless_size][!!inline_size]);
   return zm;
}

static void
zink_destroy_shader_module(struct zink_screen *screen, struct zink_shader_module *zm)
{
   VKSCR(DestroyShaderModule)(screen->dev, zm->shader, NULL);
   free(zm);
}

static void
destroy_shader_cache(struct zink_screen *screen, struct list_head *sc)
{
   struct zink_shader_module *zm, *next;
   LIST_FOR_EACH_ENTRY_SAFE(zm, next, sc, list) {
      list_delinit(&zm->list);
      zink_destroy_shader_module(screen, zm);
   }
}

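/* (re)resolve the shader module for each stage set in 'mask' and fold the
 * per-module hashes into prog->last_variant_hash; XOR lets a single stage's
 * contribution be swapped in and out without rehashing the other stages
 */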
static void
update_gfx_shader_modules(struct zink_context *ctx,
                          struct zink_screen *screen,
                          struct zink_gfx_program *prog, uint32_t mask,
                          struct zink_gfx_pipeline_state *state)
{
   bool hash_changed = false;
   bool default_variants = true;
   bool first = !prog->modules[PIPE_SHADER_VERTEX];
   uint32_t variant_hash = prog->last_variant_hash;
   u_foreach_bit(pstage, mask) {
      assert(prog->shaders[pstage]);
      struct zink_shader_module *zm = get_shader_module_for_stage(ctx, screen, prog->shaders[pstage], prog, state);
      state->modules[pstage] = zm->shader;
      if (prog->modules[pstage] == zm)
         continue;
      if (prog->modules[pstage])
         variant_hash ^= prog->modules[pstage]->hash;
      hash_changed = true;
      default_variants &= zm->default_variant;
      prog->modules[pstage] = zm;
      variant_hash ^= prog->modules[pstage]->hash;
   }

   if (hash_changed && state) {
      if (default_variants && !first)
         prog->last_variant_hash = prog->default_variant_hash;
      else {
         prog->last_variant_hash = variant_hash;
         if (first) {
            p_atomic_dec(&prog->base.reference.count);
            prog->default_variant_hash = prog->last_variant_hash;
         }
      }

      state->modules_changed = true;
   }
}

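/* hash only the "static" prefix of the pipeline state (everything before the
 * 'hash' member), then mix in whichever dyn_state blocks aren't covered by
 * dynamic-state extensions on this screen
 */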
static uint32_t
hash_gfx_pipeline_state(const void *key)
{
   const struct zink_gfx_pipeline_state *state = key;
   uint32_t hash = _mesa_hash_data(key, offsetof(struct zink_gfx_pipeline_state, hash));
   if (!state->have_EXT_extended_dynamic_state2)
      hash = XXH32(&state->dyn_state2, sizeof(state->dyn_state2), hash);
   if (state->have_EXT_extended_dynamic_state)
      return hash;
   return XXH32(&state->dyn_state1, sizeof(state->dyn_state1), hash);
}

static bool
equals_gfx_pipeline_state(const void *a, const void *b)
{
   const struct zink_gfx_pipeline_state *sa = a;
   const struct zink_gfx_pipeline_state *sb = b;
   if (sa->uses_dynamic_stride != sb->uses_dynamic_stride)
      return false;
   /* dynamic vs rp */
   if (!!sa->render_pass != !!sb->render_pass)
      return false;
   if (!sa->have_EXT_extended_dynamic_state || !sa->uses_dynamic_stride) {
      if (sa->vertex_buffers_enabled_mask != sb->vertex_buffers_enabled_mask)
         return false;
      /* without dynamic stride, we also have to compare the strides of the enabled vertex buffer bindings */
      uint32_t mask_a = sa->vertex_buffers_enabled_mask;
      uint32_t mask_b = sb->vertex_buffers_enabled_mask;
      while (mask_a || mask_b) {
         unsigned idx_a = u_bit_scan(&mask_a);
         unsigned idx_b = u_bit_scan(&mask_b);
         if (sa->vertex_strides[idx_a] != sb->vertex_strides[idx_b])
            return false;
      }
   }
   if (!sa->have_EXT_extended_dynamic_state) {
      if (memcmp(&sa->dyn_state1, &sb->dyn_state1, offsetof(struct zink_pipeline_dynamic_state1, depth_stencil_alpha_state)))
         return false;
      if (!!sa->dyn_state1.depth_stencil_alpha_state != !!sb->dyn_state1.depth_stencil_alpha_state ||
          (sa->dyn_state1.depth_stencil_alpha_state &&
           memcmp(sa->dyn_state1.depth_stencil_alpha_state, sb->dyn_state1.depth_stencil_alpha_state,
                  sizeof(struct zink_depth_stencil_alpha_hw_state))))
         return false;
   }
   if (!sa->have_EXT_extended_dynamic_state2) {
      if (memcmp(&sa->dyn_state2, &sb->dyn_state2, sizeof(sa->dyn_state2)))
         return false;
   } else if (!sa->extendedDynamicState2PatchControlPoints) {
      if (sa->dyn_state2.vertices_per_patch != sb->dyn_state2.vertices_per_patch)
         return false;
   }
   return !memcmp(sa->modules, sb->modules, sizeof(sa->modules)) &&
          !memcmp(a, b, offsetof(struct zink_gfx_pipeline_state, hash));
}

void
zink_update_gfx_program(struct zink_context *ctx, struct zink_gfx_program *prog)
{
   update_gfx_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state);
}

static void
update_cs_shader_module(struct zink_context *ctx, struct zink_compute_program *comp)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   struct zink_shader *zs = comp->shader;
   VkShaderModule mod;
   struct zink_shader_module *zm = NULL;
   unsigned inline_size = 0, nonseamless_size = 0;
   struct zink_shader_key *key = &ctx->compute_pipeline_state.key;

   if (ctx && zs->nir->info.num_inlinable_uniforms &&
       ctx->inlinable_uniforms_valid_mask & BITFIELD64_BIT(PIPE_SHADER_COMPUTE)) {
      if (screen->is_cpu || comp->inlined_variant_count < ZINK_MAX_INLINED_VARIANTS)
         inline_size = zs->nir->info.num_inlinable_uniforms;
      else
         key->inline_uniforms = false;
   }
   if (key->base.nonseamless_cube_mask)
      nonseamless_size = sizeof(uint32_t);

   if (inline_size || nonseamless_size) {
      struct zink_shader_module *iter, *next;
      LIST_FOR_EACH_ENTRY_SAFE(iter, next, &comp->shader_cache[!!nonseamless_size], list) {
         if (!shader_key_matches(iter, false, key, inline_size))
            continue;
         list_delinit(&iter->list);
         zm = iter;
         break;
      }
   } else {
      zm = comp->module;
   }

   if (!zm) {
      zm = malloc(sizeof(struct zink_shader_module) + nonseamless_size + inline_size * sizeof(uint32_t));
      if (!zm) {
         return;
      }
      mod = zink_shader_compile(screen, zs, comp->shader->nir, key);
      if (!mod) {
         FREE(zm);
         return;
      }
      zm->shader = mod;
      list_inithead(&zm->list);
      zm->num_uniforms = inline_size;
      zm->key_size = 0;
      zm->has_nonseamless = !!nonseamless_size;
      assert(nonseamless_size || inline_size);
      if (nonseamless_size)
         memcpy(zm->key, &key->base.nonseamless_cube_mask, nonseamless_size);
      if (inline_size)
         memcpy(zm->key + nonseamless_size, key->base.inlined_uniform_values, inline_size * sizeof(uint32_t));
      zm->hash = shader_module_hash(zm);
      zm->default_variant = false;
      if (inline_size)
         comp->inlined_variant_count++;
   }
   if (zm->num_uniforms || nonseamless_size)
      list_add(&zm->list, &comp->shader_cache[!!nonseamless_size]);
   if (comp->curr == zm)
      return;
   ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
   comp->curr = zm;
   ctx->compute_pipeline_state.module_hash = zm->hash;
   ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
   ctx->compute_pipeline_state.module_changed = true;
}

void
zink_update_compute_program(struct zink_context *ctx)
{
   update_cs_shader_module(ctx, ctx->curr_compute);
}

VkPipelineLayout
zink_pipeline_layout_create(struct zink_screen *screen, struct zink_program *pg, uint32_t *compat)
{
   VkPipelineLayoutCreateInfo plci = {0};
   plci.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;

   plci.pSetLayouts = pg->dsl;
   plci.setLayoutCount = pg->num_dsl;

   VkPushConstantRange pcr[2] = {0};
   if (pg->is_compute) {
      if (((struct zink_compute_program*)pg)->shader->nir->info.stage == MESA_SHADER_KERNEL) {
         pcr[0].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
         pcr[0].offset = 0;
         pcr[0].size = sizeof(struct zink_cs_push_constant);
         plci.pushConstantRangeCount = 1;
      }
   } else {
      pcr[0].stageFlags = VK_SHADER_STAGE_VERTEX_BIT;
      pcr[0].offset = offsetof(struct zink_gfx_push_constant, draw_mode_is_indexed);
      pcr[0].size = 2 * sizeof(unsigned);
      pcr[1].stageFlags = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
      pcr[1].offset = offsetof(struct zink_gfx_push_constant, default_inner_level);
      pcr[1].size = sizeof(float) * 6;
      plci.pushConstantRangeCount = 2;
   }
   plci.pPushConstantRanges = &pcr[0];

   VkPipelineLayout layout;
   VkResult result = VKSCR(CreatePipelineLayout)(screen->dev, &plci, NULL, &layout);
   if (result != VK_SUCCESS) {
      mesa_loge("vkCreatePipelineLayout failed (%s)", vk_Result_to_str(result));
      return VK_NULL_HANDLE;
   }

   *compat = _mesa_hash_data(pg->dsl, pg->num_dsl * sizeof(pg->dsl[0]));

   return layout;
}

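/* clone each producer/consumer pair of NIR shaders into the program and assign
 * matching IO locations between adjacent stages; the clones keep per-program IO
 * assignment from mutating the original shaders' NIR, which can be shared
 * across programs
 */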
static void
assign_io(struct zink_gfx_program *prog, struct zink_shader *stages[ZINK_SHADER_COUNT])
{
   struct zink_shader *shaders[PIPE_SHADER_TYPES];

   /* build array in pipeline order */
   for (unsigned i = 0; i < ZINK_SHADER_COUNT; i++)
      shaders[tgsi_processor_to_shader_stage(i)] = stages[i];

   for (unsigned i = 0; i < MESA_SHADER_FRAGMENT;) {
      nir_shader *producer = shaders[i]->nir;
      for (unsigned j = i + 1; j < ZINK_SHADER_COUNT; i++, j++) {
         struct zink_shader *consumer = shaders[j];
         if (!consumer)
            continue;
         if (!prog->nir[producer->info.stage])
            prog->nir[producer->info.stage] = nir_shader_clone(prog, producer);
         if (!prog->nir[j])
            prog->nir[j] = nir_shader_clone(prog, consumer->nir);
         zink_compiler_assign_io(prog->nir[producer->info.stage], prog->nir[j]);
         i = j;
         break;
      }
   }
}

struct zink_gfx_program *
zink_create_gfx_program(struct zink_context *ctx,
                        struct zink_shader *stages[ZINK_SHADER_COUNT],
                        unsigned vertices_per_patch)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   struct zink_gfx_program *prog = rzalloc(NULL, struct zink_gfx_program);
   if (!prog)
      goto fail;

   pipe_reference_init(&prog->base.reference, 1);
   util_queue_fence_init(&prog->base.cache_fence);

   for (int i = 0; i < ZINK_SHADER_COUNT; ++i) {
      list_inithead(&prog->shader_cache[i][0][0]);
      list_inithead(&prog->shader_cache[i][0][1]);
      list_inithead(&prog->shader_cache[i][1][0]);
      list_inithead(&prog->shader_cache[i][1][1]);
      if (stages[i]) {
         prog->shaders[i] = stages[i];
         prog->stages_present |= BITFIELD_BIT(i);
      }
   }
   if (stages[PIPE_SHADER_TESS_EVAL] && !stages[PIPE_SHADER_TESS_CTRL]) {
      prog->shaders[PIPE_SHADER_TESS_EVAL]->generated =
      prog->shaders[PIPE_SHADER_TESS_CTRL] =
        zink_shader_tcs_create(screen, stages[PIPE_SHADER_VERTEX], vertices_per_patch);
      prog->stages_present |= BITFIELD_BIT(PIPE_SHADER_TESS_CTRL);
   }

   assign_io(prog, prog->shaders);

   if (stages[PIPE_SHADER_GEOMETRY])
      prog->last_vertex_stage = stages[PIPE_SHADER_GEOMETRY];
   else if (stages[PIPE_SHADER_TESS_EVAL])
      prog->last_vertex_stage = stages[PIPE_SHADER_TESS_EVAL];
   else
      prog->last_vertex_stage = stages[PIPE_SHADER_VERTEX];

   for (int i = 0; i < ARRAY_SIZE(prog->pipelines); ++i) {
      _mesa_hash_table_init(&prog->pipelines[i], prog, NULL, equals_gfx_pipeline_state);
      /* only need first 3/4 for point/line/tri/patch */
      if (screen->info.have_EXT_extended_dynamic_state &&
          i == (prog->last_vertex_stage->nir->info.stage == MESA_SHADER_TESS_EVAL ? 4 : 3))
         break;
   }

   struct mesa_sha1 sctx;
   _mesa_sha1_init(&sctx);
   for (int i = 0; i < ZINK_SHADER_COUNT; ++i) {
      if (prog->shaders[i]) {
         simple_mtx_lock(&prog->shaders[i]->lock);
         _mesa_set_add(prog->shaders[i]->programs, prog);
         simple_mtx_unlock(&prog->shaders[i]->lock);
         zink_gfx_program_reference(ctx, NULL, prog);
         _mesa_sha1_update(&sctx, prog->shaders[i]->base.sha1, sizeof(prog->shaders[i]->base.sha1));
      }
   }
   _mesa_sha1_final(&sctx, prog->base.sha1);

   if (!screen->descriptor_program_init(ctx, &prog->base))
      goto fail;

   zink_screen_get_pipeline_cache(screen, &prog->base);
   return prog;

fail:
   if (prog)
      zink_destroy_gfx_program(ctx, prog);
   return NULL;
}

static uint32_t
hash_compute_pipeline_state(const void *key)
{
   const struct zink_compute_pipeline_state *state = key;
   uint32_t hash = _mesa_hash_data(state, offsetof(struct zink_compute_pipeline_state, hash));
   if (state->use_local_size)
      hash = XXH32(&state->local_size[0], sizeof(state->local_size), hash);
   return hash;
}

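/* a compute shader with an all-zero workgroup size gets its local size from the
 * dispatch block instead of the shader itself, so the block dimensions become
 * part of the pipeline state and a size change must mark the state dirty
 */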
void
zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink_compute_program *comp, const uint block[3])
{
   struct zink_shader *zs = comp->shader;
   bool use_local_size = !(zs->nir->info.workgroup_size[0] ||
                           zs->nir->info.workgroup_size[1] ||
                           zs->nir->info.workgroup_size[2]);
   if (ctx->compute_pipeline_state.use_local_size != use_local_size)
      ctx->compute_pipeline_state.dirty = true;
   ctx->compute_pipeline_state.use_local_size = use_local_size;

   if (ctx->compute_pipeline_state.use_local_size) {
      for (int i = 0; i < ARRAY_SIZE(ctx->compute_pipeline_state.local_size); i++) {
         if (ctx->compute_pipeline_state.local_size[i] != block[i])
            ctx->compute_pipeline_state.dirty = true;
         ctx->compute_pipeline_state.local_size[i] = block[i];
      }
   } else
      ctx->compute_pipeline_state.local_size[0] =
      ctx->compute_pipeline_state.local_size[1] =
      ctx->compute_pipeline_state.local_size[2] = 0;
}

static bool
equals_compute_pipeline_state(const void *a, const void *b)
{
   const struct zink_compute_pipeline_state *sa = a;
   const struct zink_compute_pipeline_state *sb = b;
   return !memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) &&
          sa->module == sb->module;
}

struct zink_compute_program *
zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   struct zink_compute_program *comp = rzalloc(NULL, struct zink_compute_program);
   if (!comp)
      goto fail;

   pipe_reference_init(&comp->base.reference, 1);
   util_queue_fence_init(&comp->base.cache_fence);
   comp->base.is_compute = true;

   comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module);
   assert(comp->module);
   comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL);
   assert(comp->module->shader);
   list_inithead(&comp->shader_cache[0]);
   list_inithead(&comp->shader_cache[1]);

   comp->pipelines = _mesa_hash_table_create(NULL, NULL,
                                             equals_compute_pipeline_state);

   _mesa_set_add(shader->programs, comp);
   comp->shader = shader;
   memcpy(comp->base.sha1, shader->base.sha1, sizeof(shader->base.sha1));

   if (!screen->descriptor_program_init(ctx, &comp->base))
      goto fail;

   zink_screen_get_pipeline_cache(screen, &comp->base);
   return comp;

fail:
   if (comp)
      zink_destroy_compute_program(ctx, comp);
   return NULL;
}

uint32_t
zink_program_get_descriptor_usage(struct zink_context *ctx, enum pipe_shader_type stage, enum zink_descriptor_type type)
{
   struct zink_shader *zs = NULL;
   switch (stage) {
   case PIPE_SHADER_VERTEX:
   case PIPE_SHADER_TESS_CTRL:
   case PIPE_SHADER_TESS_EVAL:
   case PIPE_SHADER_GEOMETRY:
   case PIPE_SHADER_FRAGMENT:
      zs = ctx->gfx_stages[stage];
      break;
   case PIPE_SHADER_COMPUTE: {
      zs = ctx->compute_stage;
      break;
   }
   default:
      unreachable("unknown shader type");
   }
   if (!zs)
      return 0;
   switch (type) {
   case ZINK_DESCRIPTOR_TYPE_UBO:
      return zs->ubos_used;
   case ZINK_DESCRIPTOR_TYPE_SSBO:
      return zs->ssbos_used;
   case ZINK_DESCRIPTOR_TYPE_SAMPLER_VIEW:
      return BITSET_TEST_RANGE(zs->nir->info.textures_used, 0, PIPE_MAX_SAMPLERS - 1);
   case ZINK_DESCRIPTOR_TYPE_IMAGE:
      return BITSET_TEST_RANGE(zs->nir->info.images_used, 0, PIPE_MAX_SAMPLERS - 1);
   default:
      unreachable("unknown descriptor type!");
   }
   return 0;
}

bool
zink_program_descriptor_is_buffer(struct zink_context *ctx, enum pipe_shader_type stage, enum zink_descriptor_type type, unsigned i)
{
   struct zink_shader *zs = NULL;
   switch (stage) {
   case PIPE_SHADER_VERTEX:
   case PIPE_SHADER_TESS_CTRL:
   case PIPE_SHADER_TESS_EVAL:
   case PIPE_SHADER_GEOMETRY:
   case PIPE_SHADER_FRAGMENT:
      zs = ctx->gfx_stages[stage];
      break;
   case PIPE_SHADER_COMPUTE: {
      zs = ctx->compute_stage;
      break;
   }
   default:
      unreachable("unknown shader type");
   }
   if (!zs)
      return false;
   return zink_shader_descriptor_is_buffer(zs, type, i);
}

static unsigned
get_num_bindings(struct zink_shader *zs, enum zink_descriptor_type type)
{
   switch (type) {
   case ZINK_DESCRIPTOR_TYPE_UBO:
   case ZINK_DESCRIPTOR_TYPE_SSBO:
      return zs->num_bindings[type];
   default:
      break;
   }
   unsigned num_bindings = 0;
   for (int i = 0; i < zs->num_bindings[type]; i++)
      num_bindings += zs->bindings[type][i].size;
   return num_bindings;
}

unsigned
zink_program_num_bindings_typed(const struct zink_program *pg, enum zink_descriptor_type type, bool is_compute)
{
   unsigned num_bindings = 0;
   if (is_compute) {
      struct zink_compute_program *comp = (void*)pg;
      return get_num_bindings(comp->shader, type);
   }
   struct zink_gfx_program *prog = (void*)pg;
   for (unsigned i = 0; i < ZINK_SHADER_COUNT; i++) {
      if (prog->shaders[i])
         num_bindings += get_num_bindings(prog->shaders[i], type);
   }
   return num_bindings;
}

unsigned
zink_program_num_bindings(const struct zink_program *pg, bool is_compute)
{
   unsigned num_bindings = 0;
   for (unsigned i = 0; i < ZINK_DESCRIPTOR_TYPES; i++)
      num_bindings += zink_program_num_bindings_typed(pg, i, is_compute);
   return num_bindings;
}

void
zink_destroy_gfx_program(struct zink_context *ctx,
                         struct zink_gfx_program *prog)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   util_queue_fence_wait(&prog->base.cache_fence);
   if (prog->base.layout)
      VKSCR(DestroyPipelineLayout)(screen->dev, prog->base.layout, NULL);

   for (int i = 0; i < ZINK_SHADER_COUNT; ++i) {
      if (prog->shaders[i]) {
         _mesa_set_remove_key(prog->shaders[i]->programs, prog);
         prog->shaders[i] = NULL;
      }
      destroy_shader_cache(screen, &prog->shader_cache[i][0][0]);
      destroy_shader_cache(screen, &prog->shader_cache[i][0][1]);
      destroy_shader_cache(screen, &prog->shader_cache[i][1][0]);
      destroy_shader_cache(screen, &prog->shader_cache[i][1][1]);
      ralloc_free(prog->nir[i]);
   }

   unsigned max_idx = ARRAY_SIZE(prog->pipelines);
   if (screen->info.have_EXT_extended_dynamic_state) {
      /* only need first 3/4 for point/line/tri/patch */
      if ((prog->stages_present &
          (BITFIELD_BIT(PIPE_SHADER_TESS_EVAL) | BITFIELD_BIT(PIPE_SHADER_GEOMETRY))) ==
          BITFIELD_BIT(PIPE_SHADER_TESS_EVAL))
         max_idx = 4;
      else
         max_idx = 3;
      max_idx++;
   }

   for (int i = 0; i < max_idx; ++i) {
      hash_table_foreach(&prog->pipelines[i], entry) {
         struct gfx_pipeline_cache_entry *pc_entry = entry->data;

         VKSCR(DestroyPipeline)(screen->dev, pc_entry->pipeline, NULL);
         free(pc_entry);
      }
   }
   if (prog->base.pipeline_cache)
      VKSCR(DestroyPipelineCache)(screen->dev, prog->base.pipeline_cache, NULL);
   screen->descriptor_program_deinit(ctx, &prog->base);

   ralloc_free(prog);
}

void
zink_destroy_compute_program(struct zink_context *ctx,
                             struct zink_compute_program *comp)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   util_queue_fence_wait(&comp->base.cache_fence);
   if (comp->base.layout)
      VKSCR(DestroyPipelineLayout)(screen->dev, comp->base.layout, NULL);

   if (comp->shader)
      _mesa_set_remove_key(comp->shader->programs, comp);
   destroy_shader_cache(screen, &comp->shader_cache[0]);
   destroy_shader_cache(screen, &comp->shader_cache[1]);

   hash_table_foreach(comp->pipelines, entry) {
      struct compute_pipeline_cache_entry *pc_entry = entry->data;

      VKSCR(DestroyPipeline)(screen->dev, pc_entry->pipeline, NULL);
      free(pc_entry);
   }
   _mesa_hash_table_destroy(comp->pipelines, NULL);
   VKSCR(DestroyShaderModule)(screen->dev, comp->module->shader, NULL);
   free(comp->module);
   if (comp->base.pipeline_cache)
      VKSCR(DestroyPipelineCache)(screen->dev, comp->base.pipeline_cache, NULL);
   screen->descriptor_program_deinit(ctx, &comp->base);

   ralloc_free(comp);
}

static unsigned
get_pipeline_idx(bool have_EXT_extended_dynamic_state, enum pipe_prim_type mode, VkPrimitiveTopology vkmode)
{
   /* VK_DYNAMIC_STATE_PRIMITIVE_TOPOLOGY specifies that the topology state in
    * VkPipelineInputAssemblyStateCreateInfo only specifies the topology class,
    * and the specific topology order and adjacency must be set dynamically
    * with vkCmdSetPrimitiveTopology before any drawing commands.
    */
   if (have_EXT_extended_dynamic_state) {
      if (mode == PIPE_PRIM_PATCHES)
         return 3;
      switch (u_reduced_prim(mode)) {
      case PIPE_PRIM_POINTS:
         return 0;
      case PIPE_PRIM_LINES:
         return 1;
      default:
         return 2;
      }
   }
   return vkmode;
}

/* VUID-vkCmdBindVertexBuffers2-pStrides-06209
 * If pStrides is not NULL each element of pStrides must be either 0 or greater than or equal
 * to the maximum extent of all vertex input attributes fetched from the corresponding
 * binding, where the extent is calculated as the VkVertexInputAttributeDescription::offset
 * plus VkVertexInputAttributeDescription::format size
 *
 * thus, if the stride doesn't meet the minimum requirement for a binding,
 * disable the dynamic state here and use a fully-baked pipeline
 */
static bool
check_vertex_strides(struct zink_context *ctx)
{
   const struct zink_vertex_elements_state *ves = ctx->element_state;
   for (unsigned i = 0; i < ves->hw_state.num_bindings; i++) {
      const struct pipe_vertex_buffer *vb = ctx->vertex_buffers + ves->binding_map[i];
      unsigned stride = vb->buffer.resource ? vb->stride : 0;
      if (stride && stride < ves->min_stride[i])
         return false;
   }
   return true;
}

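/* state->final_hash is maintained as a running XOR of its component hashes
 * (base state hash, vertex-state hash, shader variant hashes); a component is
 * XORed out before being recomputed so only the piece that changed is rehashed
 */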
VkPipeline
zink_get_gfx_pipeline(struct zink_context *ctx,
                      struct zink_gfx_program *prog,
                      struct zink_gfx_pipeline_state *state,
                      enum pipe_prim_type mode)
{
   struct zink_screen *screen = zink_screen(ctx->base.screen);
   const bool have_EXT_vertex_input_dynamic_state = screen->info.have_EXT_vertex_input_dynamic_state;
   const bool have_EXT_extended_dynamic_state = screen->info.have_EXT_extended_dynamic_state;
   bool uses_dynamic_stride = state->uses_dynamic_stride;

   VkPrimitiveTopology vkmode = zink_primitive_topology(mode);
   const unsigned idx = get_pipeline_idx(have_EXT_extended_dynamic_state, mode, vkmode);
   assert(idx < ARRAY_SIZE(prog->pipelines));
   if (!state->dirty && !state->modules_changed &&
       (have_EXT_vertex_input_dynamic_state || !ctx->vertex_state_changed) &&
       idx == state->idx)
      return state->pipeline;

   struct hash_entry *entry = NULL;

   if (state->dirty) {
      if (state->pipeline) //avoid on first hash
         state->final_hash ^= state->hash;
      state->hash = hash_gfx_pipeline_state(state);
      state->final_hash ^= state->hash;
      state->dirty = false;
   }
   if (!have_EXT_vertex_input_dynamic_state && ctx->vertex_state_changed) {
      if (state->pipeline)
         state->final_hash ^= state->vertex_hash;
      if (have_EXT_extended_dynamic_state)
         uses_dynamic_stride = check_vertex_strides(ctx);
      if (!uses_dynamic_stride) {
         uint32_t hash = 0;
         /* if we don't have dynamic states, we have to hash the enabled vertex buffer bindings */
         uint32_t vertex_buffers_enabled_mask = state->vertex_buffers_enabled_mask;
         hash = XXH32(&vertex_buffers_enabled_mask, sizeof(uint32_t), hash);

         for (unsigned i = 0; i < state->element_state->num_bindings; i++) {
            const unsigned buffer_id = ctx->element_state->binding_map[i];
            struct pipe_vertex_buffer *vb = ctx->vertex_buffers + buffer_id;
            state->vertex_strides[buffer_id] = vb->buffer.resource ? vb->stride : 0;
            hash = XXH32(&state->vertex_strides[buffer_id], sizeof(uint32_t), hash);
         }
         state->vertex_hash = hash ^ state->element_state->hash;
      } else
         state->vertex_hash = state->element_state->hash;
      state->final_hash ^= state->vertex_hash;
   }
   state->modules_changed = false;
   state->uses_dynamic_stride = uses_dynamic_stride;
   ctx->vertex_state_changed = false;

   entry = _mesa_hash_table_search_pre_hashed(&prog->pipelines[idx], state->final_hash, state);

   if (!entry) {
      util_queue_fence_wait(&prog->base.cache_fence);
      VkPipeline pipeline = zink_create_gfx_pipeline(screen, prog, state,
                                                     ctx->element_state->binding_map,
                                                     vkmode);
      if (pipeline == VK_NULL_HANDLE)
         return VK_NULL_HANDLE;

      zink_screen_update_pipeline_cache(screen, &prog->base);
      struct gfx_pipeline_cache_entry *pc_entry = CALLOC_STRUCT(gfx_pipeline_cache_entry);
      if (!pc_entry)
         return VK_NULL_HANDLE;

      memcpy(&pc_entry->state, state, sizeof(*state));
      pc_entry->pipeline = pipeline;

      entry = _mesa_hash_table_insert_pre_hashed(&prog->pipelines[idx], state->final_hash, pc_entry, pc_entry);
      assert(entry);
   }

   struct gfx_pipeline_cache_entry *cache_entry = entry->data;
   state->pipeline = cache_entry->pipeline;
   state->idx = idx;
   return state->pipeline;
}

VkPipeline
zink_get_compute_pipeline(struct zink_screen *screen,
                          struct zink_compute_program *comp,
                          struct zink_compute_pipeline_state *state)
{
   struct hash_entry *entry = NULL;

   if (!state->dirty && !state->module_changed)
      return state->pipeline;
   if (state->dirty) {
      if (state->pipeline) //avoid on first hash
         state->final_hash ^= state->hash;
      state->hash = hash_compute_pipeline_state(state);
      state->dirty = false;
      state->final_hash ^= state->hash;
   }
   entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->final_hash, state);

   if (!entry) {
      util_queue_fence_wait(&comp->base.cache_fence);
      VkPipeline pipeline = zink_create_compute_pipeline(screen, comp, state);

      if (pipeline == VK_NULL_HANDLE)
         return VK_NULL_HANDLE;

      struct compute_pipeline_cache_entry *pc_entry = CALLOC_STRUCT(compute_pipeline_cache_entry);
      if (!pc_entry)
         return VK_NULL_HANDLE;

      memcpy(&pc_entry->state, state, sizeof(*state));
      pc_entry->pipeline = pipeline;

      entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->final_hash, pc_entry, pc_entry);
      assert(entry);
   }

   struct compute_pipeline_cache_entry *cache_entry = entry->data;
   state->pipeline = cache_entry->pipeline;
   return state->pipeline;
}

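/* common bind path for all shader-stage CSOs: updates the inlinable-uniform
 * mask and the incrementally-maintained gfx/compute hashes; for compute, the
 * program is looked up or created immediately on bind
 */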
static inline void
bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
           struct zink_shader *shader)
{
   if (shader && shader->nir->info.num_inlinable_uniforms)
      ctx->shader_has_inlinable_uniforms_mask |= 1 << stage;
   else
      ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage);

   if (stage == PIPE_SHADER_COMPUTE) {
      if (ctx->compute_stage) {
         ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
         ctx->compute_pipeline_state.module = VK_NULL_HANDLE;
         ctx->compute_pipeline_state.module_hash = 0;
      }
      if (shader && shader != ctx->compute_stage) {
         struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader);
         if (entry) {
            ctx->compute_pipeline_state.dirty = true;
            ctx->curr_compute = entry->data;
         } else {
            struct zink_compute_program *comp = zink_create_compute_program(ctx, shader);
            _mesa_hash_table_insert(&ctx->compute_program_cache, comp->shader, comp);
            ctx->compute_pipeline_state.dirty = true;
            ctx->curr_compute = comp;
            zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
         }
         ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash;
         ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader;
         ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
         if (ctx->compute_pipeline_state.key.base.nonseamless_cube_mask)
            ctx->dirty_shader_stages |= BITFIELD_BIT(PIPE_SHADER_COMPUTE);
      } else if (!shader)
         ctx->curr_compute = NULL;
      ctx->compute_stage = shader;
      zink_select_launch_grid(ctx);
   } else {
      if (ctx->gfx_stages[stage])
         ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
      ctx->gfx_stages[stage] = shader;
      ctx->gfx_dirty = ctx->gfx_stages[PIPE_SHADER_FRAGMENT] && ctx->gfx_stages[PIPE_SHADER_VERTEX];
      ctx->gfx_pipeline_state.modules_changed = true;
      if (shader) {
         ctx->shader_stages |= BITFIELD_BIT(stage);
         ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
      } else {
         ctx->gfx_pipeline_state.modules[stage] = VK_NULL_HANDLE;
         if (ctx->curr_program)
            ctx->gfx_pipeline_state.final_hash ^= ctx->curr_program->last_variant_hash;
         ctx->curr_program = NULL;
         ctx->shader_stages &= ~BITFIELD_BIT(stage);
      }
   }
}

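/* the last vertex-processing stage (GS > TES > VS) owns the vs_base shader key
 * and determines how many viewports are enabled, so both must be rechecked
 * whenever a geometry-pipeline stage binding changes
 */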
static void
bind_last_vertex_stage(struct zink_context *ctx)
{
   enum pipe_shader_type old = ctx->last_vertex_stage ? pipe_shader_type_from_mesa(ctx->last_vertex_stage->nir->info.stage) : PIPE_SHADER_TYPES;
   if (ctx->gfx_stages[PIPE_SHADER_GEOMETRY])
      ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_GEOMETRY];
   else if (ctx->gfx_stages[PIPE_SHADER_TESS_EVAL])
      ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_TESS_EVAL];
   else
      ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_VERTEX];
   enum pipe_shader_type current = ctx->last_vertex_stage ? pipe_shader_type_from_mesa(ctx->last_vertex_stage->nir->info.stage) : PIPE_SHADER_VERTEX;
   if (old != current) {
      if (old != PIPE_SHADER_TYPES) {
         memset(&ctx->gfx_pipeline_state.shader_keys.key[old].key.vs_base, 0, sizeof(struct zink_vs_key_base));
         ctx->dirty_shader_stages |= BITFIELD_BIT(old);
      } else {
         /* always unset vertex shader values when changing to a non-vs last stage */
         memset(&ctx->gfx_pipeline_state.shader_keys.key[PIPE_SHADER_VERTEX].key.vs_base, 0, sizeof(struct zink_vs_key_base));
      }

      unsigned num_viewports = ctx->vp_state.num_viewports;
      struct zink_screen *screen = zink_screen(ctx->base.screen);
      /* number of enabled viewports is based on whether last vertex stage writes viewport index */
      if (ctx->last_vertex_stage) {
         if (ctx->last_vertex_stage->nir->info.outputs_written & (VARYING_BIT_VIEWPORT | VARYING_BIT_VIEWPORT_MASK))
            ctx->vp_state.num_viewports = MIN2(screen->info.props.limits.maxViewports, PIPE_MAX_VIEWPORTS);
         else
            ctx->vp_state.num_viewports = 1;
      } else {
         ctx->vp_state.num_viewports = 1;
      }
      ctx->vp_state_changed |= num_viewports != ctx->vp_state.num_viewports;
      if (!screen->info.have_EXT_extended_dynamic_state) {
         if (ctx->gfx_pipeline_state.dyn_state1.num_viewports != ctx->vp_state.num_viewports)
            ctx->gfx_pipeline_state.dirty = true;
         ctx->gfx_pipeline_state.dyn_state1.num_viewports = ctx->vp_state.num_viewports;
      }
      ctx->last_vertex_stage_dirty = true;
   }
}

static void
zink_bind_vs_state(struct pipe_context *pctx,
                   void *cso)
{
   struct zink_context *ctx = zink_context(pctx);
   if (!cso && !ctx->gfx_stages[PIPE_SHADER_VERTEX])
      return;
   bind_stage(ctx, PIPE_SHADER_VERTEX, cso);
   bind_last_vertex_stage(ctx);
   if (cso) {
      struct zink_shader *zs = cso;
      ctx->shader_reads_drawid = BITSET_TEST(zs->nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID);
      ctx->shader_reads_basevertex = BITSET_TEST(zs->nir->info.system_values_read, SYSTEM_VALUE_BASE_VERTEX);
   } else {
      ctx->shader_reads_drawid = false;
      ctx->shader_reads_basevertex = false;
   }
}

/* if gl_SampleMask[] is written to, we have to ensure that we get a shader with the same sample count:
 * in GL, samples==1 means ignore gl_SampleMask[]
 * in VK, gl_SampleMask[] is never ignored
 */
void
zink_update_fs_key_samples(struct zink_context *ctx)
{
   if (!ctx->gfx_stages[PIPE_SHADER_FRAGMENT])
      return;
   nir_shader *nir = ctx->gfx_stages[PIPE_SHADER_FRAGMENT]->nir;
   if (nir->info.outputs_written & (1 << FRAG_RESULT_SAMPLE_MASK)) {
      bool samples = zink_get_fs_key(ctx)->samples;
      if (samples != (ctx->fb_state.samples > 1))
         zink_set_fs_key(ctx)->samples = ctx->fb_state.samples > 1;
   }
}

static void
zink_bind_fs_state(struct pipe_context *pctx,
                   void *cso)
{
   struct zink_context *ctx = zink_context(pctx);
   if (!cso && !ctx->gfx_stages[PIPE_SHADER_FRAGMENT])
      return;
   bind_stage(ctx, PIPE_SHADER_FRAGMENT, cso);
   ctx->fbfetch_outputs = 0;
   if (cso) {
      nir_shader *nir = ctx->gfx_stages[PIPE_SHADER_FRAGMENT]->nir;
      if (nir->info.fs.uses_fbfetch_output) {
         nir_foreach_shader_out_variable(var, nir) {
            if (var->data.fb_fetch_output)
               ctx->fbfetch_outputs |= BITFIELD_BIT(var->data.location - FRAG_RESULT_DATA0);
         }
      }
      zink_update_fs_key_samples(ctx);
   }
   zink_update_fbfetch(ctx);
}

static void
zink_bind_gs_state(struct pipe_context *pctx,
                   void *cso)
{
   struct zink_context *ctx = zink_context(pctx);
   if (!cso && !ctx->gfx_stages[PIPE_SHADER_GEOMETRY])
      return;
   bool had_points = ctx->gfx_stages[PIPE_SHADER_GEOMETRY] ? ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->nir->info.gs.output_primitive == SHADER_PRIM_POINTS : false;
   bind_stage(ctx, PIPE_SHADER_GEOMETRY, cso);
   bind_last_vertex_stage(ctx);
   if (cso) {
      if (!had_points && ctx->last_vertex_stage->nir->info.gs.output_primitive == SHADER_PRIM_POINTS)
         ctx->gfx_pipeline_state.has_points++;
   } else {
      if (had_points)
         ctx->gfx_pipeline_state.has_points--;
   }
}

static void
zink_bind_tcs_state(struct pipe_context *pctx,
                   void *cso)
{
   bind_stage(zink_context(pctx), PIPE_SHADER_TESS_CTRL, cso);
}

static void
zink_bind_tes_state(struct pipe_context *pctx,
                   void *cso)
{
   struct zink_context *ctx = zink_context(pctx);
   if (!cso && !ctx->gfx_stages[PIPE_SHADER_TESS_EVAL])
      return;
   if (!!ctx->gfx_stages[PIPE_SHADER_TESS_EVAL] != !!cso) {
      if (!cso) {
         /* if unsetting a TES that uses a generated TCS, make sure the TCS is unset too */
         if (ctx->gfx_stages[PIPE_SHADER_TESS_EVAL]->generated)
            ctx->gfx_stages[PIPE_SHADER_TESS_CTRL] = NULL;
      }
   }
   bind_stage(ctx, PIPE_SHADER_TESS_EVAL, cso);
   bind_last_vertex_stage(ctx);
}

static void *
zink_create_cs_state(struct pipe_context *pctx,
                     const struct pipe_compute_state *shader)
{
   struct nir_shader *nir;
   if (shader->ir_type != PIPE_SHADER_IR_NIR)
      nir = zink_tgsi_to_nir(pctx->screen, shader->prog);
   else
      nir = (struct nir_shader *)shader->prog;

   return zink_shader_create(zink_screen(pctx->screen), nir, NULL);
}

static void
zink_bind_cs_state(struct pipe_context *pctx,
                   void *cso)
{
   bind_stage(zink_context(pctx), PIPE_SHADER_COMPUTE, cso);
}

void
zink_delete_shader_state(struct pipe_context *pctx, void *cso)
{
   zink_shader_free(zink_context(pctx), cso);
}

void *
zink_create_gfx_shader_state(struct pipe_context *pctx, const struct pipe_shader_state *shader)
{
   nir_shader *nir;
   if (shader->type != PIPE_SHADER_IR_NIR)
      nir = zink_tgsi_to_nir(pctx->screen, shader->tokens);
   else
      nir = (struct nir_shader *)shader->ir.nir;

   return zink_shader_create(zink_screen(pctx->screen), nir, &shader->stream_output);
}

static void
zink_delete_cached_shader_state(struct pipe_context *pctx, void *cso)
{
   struct zink_screen *screen = zink_screen(pctx->screen);
   util_shader_reference(pctx, &screen->shaders, &cso, NULL);
}

static void *
zink_create_cached_shader_state(struct pipe_context *pctx, const struct pipe_shader_state *shader)
{
   bool cache_hit;
   struct zink_screen *screen = zink_screen(pctx->screen);
   return util_live_shader_cache_get(pctx, &screen->shaders, shader, &cache_hit);
}

void
zink_program_init(struct zink_context *ctx)
{
   ctx->base.create_vs_state = zink_create_cached_shader_state;
   ctx->base.bind_vs_state = zink_bind_vs_state;
   ctx->base.delete_vs_state = zink_delete_cached_shader_state;

   ctx->base.create_fs_state = zink_create_cached_shader_state;
   ctx->base.bind_fs_state = zink_bind_fs_state;
   ctx->base.delete_fs_state = zink_delete_cached_shader_state;

   ctx->base.create_gs_state = zink_create_cached_shader_state;
   ctx->base.bind_gs_state = zink_bind_gs_state;
   ctx->base.delete_gs_state = zink_delete_cached_shader_state;

   ctx->base.create_tcs_state = zink_create_cached_shader_state;
   ctx->base.bind_tcs_state = zink_bind_tcs_state;
   ctx->base.delete_tcs_state = zink_delete_cached_shader_state;

   ctx->base.create_tes_state = zink_create_cached_shader_state;
   ctx->base.bind_tes_state = zink_bind_tes_state;
   ctx->base.delete_tes_state = zink_delete_cached_shader_state;

   ctx->base.create_compute_state = zink_create_cs_state;
   ctx->base.bind_compute_state = zink_bind_cs_state;
   ctx->base.delete_compute_state = zink_delete_shader_state;
}

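/* when 'disable' is set, rasterizer discard is forced off regardless of the
 * bound rasterizer state; returns true if the pipeline value actually changed
 */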
bool
zink_set_rasterizer_discard(struct zink_context *ctx, bool disable)
{
   bool value = disable ? false : (ctx->rast_state ? ctx->rast_state->base.rasterizer_discard : false);
   bool changed = ctx->gfx_pipeline_state.dyn_state2.rasterizer_discard != value;
   ctx->gfx_pipeline_state.dyn_state2.rasterizer_discard = value;
   if (!changed)
      return false;
   if (!zink_screen(ctx->base.screen)->info.have_EXT_extended_dynamic_state2)
      ctx->gfx_pipeline_state.dirty |= true;
   ctx->rasterizer_discard_changed = true;
   return true;
}