1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "d3d12_pipeline_state.h"
25#include "d3d12_compiler.h"
26#include "d3d12_context.h"
27#include "d3d12_screen.h"
28
29#include "util/hash_table.h"
30#include "util/set.h"
31#include "util/u_memory.h"
32#include "util/u_prim.h"
33
34#include <dxguids/dxguids.h>
35
/* Graphics PSO cache entry.
 *
 * `key` is a private copy of the pipeline state this PSO was built for; the
 * cache uses &key as its hash-table key.  `pso` is the created pipeline state
 * object, released when the entry is deleted from the cache.
 */
struct d3d12_gfx_pso_entry {
   struct d3d12_gfx_pipeline_state key;
   ID3D12PipelineState *pso;
};
40
/* Compute PSO cache entry.
 *
 * `key` is a private copy of the compute pipeline state this PSO was built
 * for; the cache uses &key as its hash-table key.  `pso` is the created
 * pipeline state object, released when the entry is deleted from the cache.
 */
struct d3d12_compute_pso_entry {
   struct d3d12_compute_pipeline_state key;
   ID3D12PipelineState *pso;
};
45
46static const char *
47get_semantic_name(int location, int driver_location, unsigned *index)
48{
49   *index = 0; /* Default index */
50
51   switch (location) {
52
53   case VARYING_SLOT_POS:
54      return "SV_Position";
55
56    case VARYING_SLOT_FACE:
57      return "SV_IsFrontFace";
58
59   case VARYING_SLOT_CLIP_DIST1:
60      *index = 1;
61      FALLTHROUGH;
62   case VARYING_SLOT_CLIP_DIST0:
63      return "SV_ClipDistance";
64
65   case VARYING_SLOT_PRIMITIVE_ID:
66      return "SV_PrimitiveID";
67
68   case VARYING_SLOT_VIEWPORT:
69      return "SV_ViewportArrayIndex";
70
71   case VARYING_SLOT_LAYER:
72      return "SV_RenderTargetArrayIndex";
73
74   default: {
75         *index = driver_location;
76         return "TEXCOORD";
77      }
78   }
79}
80
81static nir_variable *
82find_so_variable(nir_shader *s, int location, unsigned location_frac, unsigned num_components)
83{
84   nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
85      if (var->data.location != location || var->data.location_frac > location_frac)
86         continue;
87      unsigned var_num_components = var->data.compact ?
88         glsl_get_length(var->type) : glsl_get_components(var->type);
89      if (var->data.location_frac <= location_frac &&
90          var->data.location_frac + var_num_components >= location_frac + num_components)
91         return var;
92   }
93   return nullptr;
94}
95
96static void
97fill_so_declaration(const struct pipe_stream_output_info *info,
98                    nir_shader *last_vertex_stage,
99                    D3D12_SO_DECLARATION_ENTRY *entries, UINT *num_entries,
100                    UINT *strides, UINT *num_strides)
101{
102   int next_offset[MAX_VERTEX_STREAMS] = { 0 };
103
104   *num_entries = 0;
105
106   for (unsigned i = 0; i < info->num_outputs; i++) {
107      const struct pipe_stream_output *output = &info->output[i];
108      const int buffer = output->output_buffer;
109      unsigned index;
110
111      /* Mesa doesn't store entries for gl_SkipComponents in the Outputs[]
112       * array.  Instead, it simply increments DstOffset for the following
113       * input by the number of components that should be skipped.
114       *
115       * DirectX12 requires that we create gap entries.
116       */
117      int skip_components = output->dst_offset - next_offset[buffer];
118
119      if (skip_components > 0) {
120         entries[*num_entries].Stream = output->stream;
121         entries[*num_entries].SemanticName = NULL;
122         entries[*num_entries].ComponentCount = skip_components;
123         entries[*num_entries].OutputSlot = buffer;
124         (*num_entries)++;
125      }
126
127      next_offset[buffer] = output->dst_offset + output->num_components;
128
129      entries[*num_entries].Stream = output->stream;
130      nir_variable *var = find_so_variable(last_vertex_stage,
131         output->register_index, output->start_component, output->num_components);
132      assert((var->data.stream & ~NIR_STREAM_PACKED) == output->stream);
133      entries[*num_entries].SemanticName = get_semantic_name(var->data.location,
134         var->data.driver_location, &index);
135      entries[*num_entries].SemanticIndex = index;
136      entries[*num_entries].StartComponent = output->start_component - var->data.location_frac;
137      entries[*num_entries].ComponentCount = output->num_components;
138      entries[*num_entries].OutputSlot = buffer;
139      (*num_entries)++;
140   }
141
142   for (unsigned i = 0; i < MAX_VERTEX_STREAMS; i++)
143      strides[i] = info->stride[i] * 4;
144   *num_strides = MAX_VERTEX_STREAMS;
145}
146
147static bool
148depth_bias(struct d3d12_rasterizer_state *state, enum pipe_prim_type reduced_prim)
149{
150   /* glPolygonOffset is supposed to be only enabled when rendering polygons.
151    * In d3d12 case, all polygons (and quads) are lowered to triangles */
152   if (reduced_prim != PIPE_PRIM_TRIANGLES)
153      return false;
154
155   unsigned fill_mode = state->base.cull_face == PIPE_FACE_FRONT ? state->base.fill_back
156                                                                 : state->base.fill_front;
157
158   switch (fill_mode) {
159   case PIPE_POLYGON_MODE_FILL:
160      return state->base.offset_tri;
161
162   case PIPE_POLYGON_MODE_LINE:
163      return state->base.offset_line;
164
165   case PIPE_POLYGON_MODE_POINT:
166      return state->base.offset_point;
167
168   default:
169      unreachable("unexpected fill mode");
170   }
171}
172
173static D3D12_PRIMITIVE_TOPOLOGY_TYPE
174topology_type(enum pipe_prim_type reduced_prim)
175{
176   switch (reduced_prim) {
177   case PIPE_PRIM_POINTS:
178      return D3D12_PRIMITIVE_TOPOLOGY_TYPE_POINT;
179
180   case PIPE_PRIM_LINES:
181      return D3D12_PRIMITIVE_TOPOLOGY_TYPE_LINE;
182
183   case PIPE_PRIM_TRIANGLES:
184      return D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
185
186   case PIPE_PRIM_PATCHES:
187      return D3D12_PRIMITIVE_TOPOLOGY_TYPE_PATCH;
188
189   default:
190      debug_printf("pipe_prim_type: %s\n", u_prim_name(reduced_prim));
191      unreachable("unexpected enum pipe_prim_type");
192   }
193}
194
195DXGI_FORMAT
196d3d12_rtv_format(struct d3d12_context *ctx, unsigned index)
197{
198   DXGI_FORMAT fmt = ctx->gfx_pipeline_state.rtv_formats[index];
199
200   if (ctx->gfx_pipeline_state.blend->desc.RenderTarget[0].LogicOpEnable &&
201       !ctx->gfx_pipeline_state.has_float_rtv) {
202      switch (fmt) {
203      case DXGI_FORMAT_R8G8B8A8_SNORM:
204      case DXGI_FORMAT_R8G8B8A8_UNORM:
205      case DXGI_FORMAT_B8G8R8A8_UNORM:
206      case DXGI_FORMAT_B8G8R8X8_UNORM:
207         return DXGI_FORMAT_R8G8B8A8_UINT;
208      default:
209         unreachable("unsupported logic-op format");
210      }
211   }
212
213   return fmt;
214}
215
216static ID3D12PipelineState *
217create_gfx_pipeline_state(struct d3d12_context *ctx)
218{
219   struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
220   struct d3d12_gfx_pipeline_state *state = &ctx->gfx_pipeline_state;
221   enum pipe_prim_type reduced_prim = state->prim_type == PIPE_PRIM_PATCHES ?
222      PIPE_PRIM_PATCHES : u_reduced_prim(state->prim_type);
223   D3D12_SO_DECLARATION_ENTRY entries[PIPE_MAX_SO_OUTPUTS] = {};
224   UINT strides[PIPE_MAX_SO_OUTPUTS] = { 0 };
225   UINT num_entries = 0, num_strides = 0;
226
227   D3D12_GRAPHICS_PIPELINE_STATE_DESC pso_desc = { 0 };
228   pso_desc.pRootSignature = state->root_signature;
229
230   nir_shader *last_vertex_stage_nir = NULL;
231
232   if (state->stages[PIPE_SHADER_VERTEX]) {
233      auto shader = state->stages[PIPE_SHADER_VERTEX];
234      pso_desc.VS.BytecodeLength = shader->bytecode_length;
235      pso_desc.VS.pShaderBytecode = shader->bytecode;
236      last_vertex_stage_nir = shader->nir;
237   }
238
239   if (state->stages[PIPE_SHADER_TESS_CTRL]) {
240      auto shader = state->stages[PIPE_SHADER_TESS_CTRL];
241      pso_desc.HS.BytecodeLength = shader->bytecode_length;
242      pso_desc.HS.pShaderBytecode = shader->bytecode;
243      last_vertex_stage_nir = shader->nir;
244   }
245
246   if (state->stages[PIPE_SHADER_TESS_EVAL]) {
247      auto shader = state->stages[PIPE_SHADER_TESS_EVAL];
248      pso_desc.DS.BytecodeLength = shader->bytecode_length;
249      pso_desc.DS.pShaderBytecode = shader->bytecode;
250      last_vertex_stage_nir = shader->nir;
251   }
252
253   if (state->stages[PIPE_SHADER_GEOMETRY]) {
254      auto shader = state->stages[PIPE_SHADER_GEOMETRY];
255      pso_desc.GS.BytecodeLength = shader->bytecode_length;
256      pso_desc.GS.pShaderBytecode = shader->bytecode;
257      last_vertex_stage_nir = shader->nir;
258   }
259
260   bool last_vertex_stage_writes_pos = (last_vertex_stage_nir->info.outputs_written & VARYING_BIT_POS) != 0;
261   if (last_vertex_stage_writes_pos && state->stages[PIPE_SHADER_FRAGMENT] &&
262       !state->rast->base.rasterizer_discard) {
263      auto shader = state->stages[PIPE_SHADER_FRAGMENT];
264      pso_desc.PS.BytecodeLength = shader->bytecode_length;
265      pso_desc.PS.pShaderBytecode = shader->bytecode;
266   }
267
268   if (state->num_so_targets)
269      fill_so_declaration(&state->so_info, last_vertex_stage_nir, entries, &num_entries, strides, &num_strides);
270   pso_desc.StreamOutput.NumEntries = num_entries;
271   pso_desc.StreamOutput.pSODeclaration = entries;
272   pso_desc.StreamOutput.RasterizedStream = state->rast->base.rasterizer_discard ? D3D12_SO_NO_RASTERIZED_STREAM : 0;
273   pso_desc.StreamOutput.NumStrides = num_strides;
274   pso_desc.StreamOutput.pBufferStrides = strides;
275
276   pso_desc.BlendState = state->blend->desc;
277   if (state->has_float_rtv)
278      pso_desc.BlendState.RenderTarget[0].LogicOpEnable = FALSE;
279
280   pso_desc.DepthStencilState = state->zsa->desc;
281   pso_desc.SampleMask = state->sample_mask;
282   pso_desc.RasterizerState = state->rast->desc;
283
284   if (reduced_prim != PIPE_PRIM_TRIANGLES)
285      pso_desc.RasterizerState.CullMode = D3D12_CULL_MODE_NONE;
286
287   if (depth_bias(state->rast, reduced_prim)) {
288      pso_desc.RasterizerState.DepthBias = state->rast->base.offset_units * 2;
289      pso_desc.RasterizerState.DepthBiasClamp = state->rast->base.offset_clamp;
290      pso_desc.RasterizerState.SlopeScaledDepthBias = state->rast->base.offset_scale;
291   }
292
293   pso_desc.InputLayout.pInputElementDescs = state->ves->elements;
294   pso_desc.InputLayout.NumElements = state->ves->num_elements;
295
296   pso_desc.IBStripCutValue = state->ib_strip_cut_value;
297
298   pso_desc.PrimitiveTopologyType = topology_type(reduced_prim);
299
300   pso_desc.NumRenderTargets = state->num_cbufs;
301   for (unsigned i = 0; i < state->num_cbufs; ++i)
302      pso_desc.RTVFormats[i] = d3d12_rtv_format(ctx, i);
303   pso_desc.DSVFormat = state->dsv_format;
304
305   if (state->num_cbufs || state->dsv_format != DXGI_FORMAT_UNKNOWN) {
306      pso_desc.SampleDesc.Count = state->samples;
307      if (!state->zsa->desc.DepthEnable &&
308          !state->zsa->desc.StencilEnable &&
309          !state->rast->desc.MultisampleEnable &&
310          state->samples > 1) {
311         pso_desc.RasterizerState.ForcedSampleCount = 1;
312         pso_desc.DSVFormat = DXGI_FORMAT_UNKNOWN;
313      }
314   } else if (state->samples > 1) {
315      pso_desc.SampleDesc.Count = 1;
316      pso_desc.RasterizerState.ForcedSampleCount = state->samples;
317   }
318   pso_desc.SampleDesc.Quality = 0;
319
320   pso_desc.NodeMask = 0;
321
322   pso_desc.CachedPSO.pCachedBlob = NULL;
323   pso_desc.CachedPSO.CachedBlobSizeInBytes = 0;
324
325   pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
326
327   ID3D12PipelineState *ret;
328   if (FAILED(screen->dev->CreateGraphicsPipelineState(&pso_desc,
329                                                       IID_PPV_ARGS(&ret)))) {
330      debug_printf("D3D12: CreateGraphicsPipelineState failed!\n");
331      return NULL;
332   }
333
334   return ret;
335}
336
337static uint32_t
338hash_gfx_pipeline_state(const void *key)
339{
340   return _mesa_hash_data(key, sizeof(struct d3d12_gfx_pipeline_state));
341}
342
343static bool
344equals_gfx_pipeline_state(const void *a, const void *b)
345{
346   return memcmp(a, b, sizeof(struct d3d12_gfx_pipeline_state)) == 0;
347}
348
349ID3D12PipelineState *
350d3d12_get_gfx_pipeline_state(struct d3d12_context *ctx)
351{
352   uint32_t hash = hash_gfx_pipeline_state(&ctx->gfx_pipeline_state);
353   struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->pso_cache, hash,
354                                                                 &ctx->gfx_pipeline_state);
355   if (!entry) {
356      struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)MALLOC(sizeof(struct d3d12_gfx_pso_entry));
357      if (!data)
358         return NULL;
359
360      data->key = ctx->gfx_pipeline_state;
361      data->pso = create_gfx_pipeline_state(ctx);
362      if (!data->pso) {
363         FREE(data);
364         return NULL;
365      }
366
367      entry = _mesa_hash_table_insert_pre_hashed(ctx->pso_cache, hash, &data->key, data);
368      assert(entry);
369   }
370
371   return ((struct d3d12_gfx_pso_entry *)(entry->data))->pso;
372}
373
374void
375d3d12_gfx_pipeline_state_cache_init(struct d3d12_context *ctx)
376{
377   ctx->pso_cache = _mesa_hash_table_create(NULL, NULL, equals_gfx_pipeline_state);
378}
379
380static void
381delete_gfx_entry(struct hash_entry *entry)
382{
383   struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
384   data->pso->Release();
385   FREE(data);
386}
387
388static void
389remove_gfx_entry(struct d3d12_context *ctx, struct hash_entry *entry)
390{
391   struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
392
393   if (ctx->current_gfx_pso == data->pso)
394      ctx->current_gfx_pso = NULL;
395   _mesa_hash_table_remove(ctx->pso_cache, entry);
396   delete_gfx_entry(entry);
397}
398
399void
400d3d12_gfx_pipeline_state_cache_destroy(struct d3d12_context *ctx)
401{
402   _mesa_hash_table_destroy(ctx->pso_cache, delete_gfx_entry);
403}
404
405void
406d3d12_gfx_pipeline_state_cache_invalidate(struct d3d12_context *ctx, const void *state)
407{
408   hash_table_foreach(ctx->pso_cache, entry) {
409      const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
410      if (key->blend == state || key->zsa == state || key->rast == state)
411         remove_gfx_entry(ctx, entry);
412   }
413}
414
415void
416d3d12_gfx_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
417                                                 enum pipe_shader_type stage,
418                                                 struct d3d12_shader_selector *selector)
419{
420   struct d3d12_shader *shader = selector->first;
421
422   while (shader) {
423      hash_table_foreach(ctx->pso_cache, entry) {
424         const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
425         if (key->stages[stage] == shader)
426            remove_gfx_entry(ctx, entry);
427      }
428      shader = shader->next_variant;
429   }
430}
431
432static ID3D12PipelineState *
433create_compute_pipeline_state(struct d3d12_context *ctx)
434{
435   struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
436   struct d3d12_compute_pipeline_state *state = &ctx->compute_pipeline_state;
437
438   D3D12_COMPUTE_PIPELINE_STATE_DESC pso_desc = { 0 };
439   pso_desc.pRootSignature = state->root_signature;
440
441   if (state->stage) {
442      auto shader = state->stage;
443      pso_desc.CS.BytecodeLength = shader->bytecode_length;
444      pso_desc.CS.pShaderBytecode = shader->bytecode;
445   }
446
447   pso_desc.NodeMask = 0;
448
449   pso_desc.CachedPSO.pCachedBlob = NULL;
450   pso_desc.CachedPSO.CachedBlobSizeInBytes = 0;
451
452   pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
453
454   ID3D12PipelineState *ret;
455   if (FAILED(screen->dev->CreateComputePipelineState(&pso_desc,
456                                                      IID_PPV_ARGS(&ret)))) {
457      debug_printf("D3D12: CreateComputePipelineState failed!\n");
458      return NULL;
459   }
460
461   return ret;
462}
463
464static uint32_t
465hash_compute_pipeline_state(const void *key)
466{
467   return _mesa_hash_data(key, sizeof(struct d3d12_compute_pipeline_state));
468}
469
470static bool
471equals_compute_pipeline_state(const void *a, const void *b)
472{
473   return memcmp(a, b, sizeof(struct d3d12_compute_pipeline_state)) == 0;
474}
475
476ID3D12PipelineState *
477d3d12_get_compute_pipeline_state(struct d3d12_context *ctx)
478{
479   uint32_t hash = hash_compute_pipeline_state(&ctx->compute_pipeline_state);
480   struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->compute_pso_cache, hash,
481                                                                 &ctx->compute_pipeline_state);
482   if (!entry) {
483      struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)MALLOC(sizeof(struct d3d12_compute_pso_entry));
484      if (!data)
485         return NULL;
486
487      data->key = ctx->compute_pipeline_state;
488      data->pso = create_compute_pipeline_state(ctx);
489      if (!data->pso) {
490         FREE(data);
491         return NULL;
492      }
493
494      entry = _mesa_hash_table_insert_pre_hashed(ctx->compute_pso_cache, hash, &data->key, data);
495      assert(entry);
496   }
497
498   return ((struct d3d12_compute_pso_entry *)(entry->data))->pso;
499}
500
501void
502d3d12_compute_pipeline_state_cache_init(struct d3d12_context *ctx)
503{
504   ctx->compute_pso_cache = _mesa_hash_table_create(NULL, NULL, equals_compute_pipeline_state);
505}
506
507static void
508delete_compute_entry(struct hash_entry *entry)
509{
510   struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
511   data->pso->Release();
512   FREE(data);
513}
514
515static void
516remove_compute_entry(struct d3d12_context *ctx, struct hash_entry *entry)
517{
518   struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
519
520   if (ctx->current_compute_pso == data->pso)
521      ctx->current_compute_pso = NULL;
522   _mesa_hash_table_remove(ctx->compute_pso_cache, entry);
523   delete_compute_entry(entry);
524}
525
526void
527d3d12_compute_pipeline_state_cache_destroy(struct d3d12_context *ctx)
528{
529   _mesa_hash_table_destroy(ctx->compute_pso_cache, delete_compute_entry);
530}
531
532void
533d3d12_compute_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
534                                                     struct d3d12_shader_selector *selector)
535{
536   struct d3d12_shader *shader = selector->first;
537
538   while (shader) {
539      hash_table_foreach(ctx->compute_pso_cache, entry) {
540         const struct d3d12_compute_pipeline_state *key = (struct d3d12_compute_pipeline_state *)entry->key;
541         if (key->stage == shader)
542            remove_compute_entry(ctx, entry);
543      }
544      shader = shader->next_variant;
545   }
546}
547