1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "nir.h"
25#include "nir_builder.h"
26#include "d3d12_context.h"
27#include "d3d12_compiler.h"
28#include "d3d12_nir_passes.h"
29#include "d3d12_screen.h"
30
31static uint32_t
32hash_tcs_variant_key(const void *key)
33{
34   return _mesa_hash_data(key, sizeof(struct d3d12_tcs_variant_key));
35}
36
37static bool
38equals_tcs_variant_key(const void *a, const void *b)
39{
40   return memcmp(a, b, sizeof(struct d3d12_tcs_variant_key)) == 0;
41}
42
43void
44d3d12_tcs_variant_cache_init(struct d3d12_context *ctx)
45{
46   ctx->tcs_variant_cache = _mesa_hash_table_create(NULL, NULL, equals_tcs_variant_key);
47}
48
49static void
50delete_entry(struct hash_entry *entry)
51{
52   d3d12_shader_free((d3d12_shader_selector *)entry->data);
53}
54
55void
56d3d12_tcs_variant_cache_destroy(struct d3d12_context *ctx)
57{
58   _mesa_hash_table_destroy(ctx->tcs_variant_cache, delete_entry);
59}
60
61static void
62copy_vars(nir_builder *b, nir_deref_instr *dst, nir_deref_instr *src)
63{
64   assert(glsl_get_bare_type(dst->type) == glsl_get_bare_type(src->type));
65   if (glsl_type_is_struct(dst->type)) {
66      for (unsigned i = 0; i < glsl_get_length(dst->type); ++i) {
67         copy_vars(b, nir_build_deref_struct(b, dst, i), nir_build_deref_struct(b, src, i));
68      }
69   } else if (glsl_type_is_array_or_matrix(dst->type)) {
70      copy_vars(b, nir_build_deref_array_wildcard(b, dst), nir_build_deref_array_wildcard(b, src));
71   } else {
72      nir_copy_deref(b, dst, src);
73   }
74}
75
/* Build a passthrough tessellation control shader for the given variant key
 * and compile it with d3d12_create_shader().
 *
 * The generated NIR shader:
 *  - for every varying component described by key->varyings, declares an
 *    input and an output array of key->vertices_out elements and copies
 *    in[invocation_id] to out[i] for the matching output vertex i;
 *  - writes gl_TessLevelInner[0..1] / gl_TessLevelOuter[0..3] from the
 *    D3D12_STATE_VAR_DEFAULT_INNER/OUTER_TESS_LEVEL state variables.
 *
 * Returns NULL if d3d12_create_shader() fails. On success the selector is
 * flagged as a variant and holds a copy of *key in tcs->tcs_key. */
static struct d3d12_shader_selector *
create_tess_ctrl_shader_variant(struct d3d12_context *ctx, struct d3d12_tcs_variant_key *key)
{
   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_TESS_CTRL, &d3d12_screen(ctx->base.screen)->nir_options, "passthrough");
   nir_shader *nir = b.shader;

   nir_ssa_def *invocation_id = nir_load_invocation_id(&b);
   uint64_t varying_mask = key->varyings.mask;

   /* Visit every varying slot set in the mask. Within a slot, each set bit of
    * location_frac_mask is a component (location_frac) with its own type, and
    * each gets its own in/out variable pair. */
   while(varying_mask) {
      int var_idx = u_bit_scan64(&varying_mask);
      auto slot = &key->varyings.slots[var_idx];
      unsigned frac_mask = slot->location_frac_mask;
      while (frac_mask) {
         int frac = u_bit_scan(&frac_mask);
         auto var = &slot->vars[frac];
         /* Per-vertex array sized by the output vertex count, used for both
          * the input and the output variable.
          * NOTE(review): sizing the input by vertices_out assumes the input
          * patch has the same vertex count — confirm. */
         const struct glsl_type *type = glsl_array_type(slot->types[frac], key->vertices_out, 0);

         /* Variable names are "in_<driver_location>" / "out_<driver_location>". */
         char buf[1024];
         snprintf(buf, sizeof(buf), "in_%d", var->driver_location);
         nir_variable *in = nir_variable_create(nir, nir_var_shader_in, type, buf);
         snprintf(buf, sizeof(buf), "out_%d", var->driver_location);
         nir_variable *out = nir_variable_create(nir, nir_var_shader_out, type, buf);
         /* Input and output share location, component and driver location. */
         out->data.location = in->data.location = var_idx;
         out->data.location_frac = in->data.location_frac = frac;
         out->data.driver_location = in->data.driver_location = var->driver_location;

         /* Emit "if (invocation_id == i) out[i] = in[invocation_id];" for each
          * output vertex i, so every output store uses a constant index.
          * NOTE(review): presumably this avoids indexing outputs by
          * invocation_id directly — confirm the backend requirement. */
         for (unsigned i = 0; i < key->vertices_out; i++) {
            nir_if *start_block = nir_push_if(&b, nir_ieq(&b, invocation_id, nir_imm_int(&b, i)));
            nir_deref_instr *in_array_var = nir_build_deref_array(&b, nir_build_deref_var(&b, in), invocation_id);
            nir_deref_instr *out_array_var = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, out), i);
            copy_vars(&b, out_array_var, in_array_var);
            nir_pop_if(&b, start_block);
         }
      }
   }
   /* Per-patch, compact tess-level outputs: float[2] inner, float[4] outer. */
   nir_variable *gl_TessLevelInner = nir_variable_create(nir, nir_var_shader_out, glsl_array_type(glsl_float_type(), 2, 0), "gl_TessLevelInner");
   gl_TessLevelInner->data.location = VARYING_SLOT_TESS_LEVEL_INNER;
   gl_TessLevelInner->data.patch = 1;
   gl_TessLevelInner->data.compact = 1;
   nir_variable *gl_TessLevelOuter = nir_variable_create(nir, nir_var_shader_out, glsl_array_type(glsl_float_type(), 4, 0), "gl_TessLevelOuter");
   gl_TessLevelOuter->data.location = VARYING_SLOT_TESS_LEVEL_OUTER;
   gl_TessLevelOuter->data.patch = 1;
   gl_TessLevelOuter->data.compact = 1;

   /* Load the default tess levels (vec2 inner, vec4 outer) from d3d12 state vars. */
   nir_variable *state_var_inner = NULL, *state_var_outer = NULL;
   nir_ssa_def *load_inner = d3d12_get_state_var(&b, D3D12_STATE_VAR_DEFAULT_INNER_TESS_LEVEL, "d3d12_TessLevelInner", glsl_vec_type(2), &state_var_inner);
   nir_ssa_def *load_outer = d3d12_get_state_var(&b, D3D12_STATE_VAR_DEFAULT_OUTER_TESS_LEVEL, "d3d12_TessLevelOuter", glsl_vec4_type(), &state_var_outer);

   /* Scatter each loaded component into its tess-level array element. The
    * store is scalar, so only bit 0 of the 0xff writemask is meaningful. */
   for (unsigned i = 0; i < 2; i++) {
      nir_deref_instr *store_idx = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, gl_TessLevelInner), i);
      nir_store_deref(&b, store_idx, nir_channel(&b, load_inner, i), 0xff);
   }
   for (unsigned i = 0; i < 4; i++) {
      nir_deref_instr *store_idx = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, gl_TessLevelOuter), i);
      nir_store_deref(&b, store_idx, nir_channel(&b, load_outer, i), 0xff);
   }

   nir->info.tess.tcs_vertices_out = key->vertices_out;
   nir_validate_shader(nir, "created");
   /* Lower the copy_deref instructions emitted by copy_vars(). */
   NIR_PASS_V(nir, nir_lower_var_copies);

   /* Only type, ir.nir and stream_output.num_outputs are set here.
    * NOTE(review): other templ fields stay uninitialized — confirm that
    * d3d12_create_shader() reads nothing else. */
   struct pipe_shader_state templ;

   templ.type = PIPE_SHADER_IR_NIR;
   templ.ir.nir = nir;
   templ.stream_output.num_outputs = 0;

   d3d12_shader_selector *tcs = d3d12_create_shader(ctx, PIPE_SHADER_TESS_CTRL, &templ);
   if (tcs) {
      /* Keep a copy of the key inside the selector; the variant cache uses
       * &tcs->tcs_key as its hash-table key. */
      tcs->is_variant = true;
      memcpy(&tcs->tcs_key, key, sizeof(*key));
   }
   return tcs;
}
151
152d3d12_shader_selector *
153d3d12_get_tcs_variant(struct d3d12_context *ctx, struct d3d12_tcs_variant_key *key)
154{
155   uint32_t hash = hash_tcs_variant_key(key);
156   struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->tcs_variant_cache,
157      hash, key);
158   if (!entry) {
159      d3d12_shader_selector *tcs = create_tess_ctrl_shader_variant(ctx, key);
160      entry = _mesa_hash_table_insert_pre_hashed(ctx->tcs_variant_cache,
161         hash, &tcs->tcs_key, tcs);
162      assert(entry);
163   }
164
165   return (d3d12_shader_selector *)entry->data;
166}
167