1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "d3d12_compiler.h"
27 #include "d3d12_nir_passes.h"
28 #include "dxil_nir.h"
29 #include "program/prog_statevars.h"
30 
/**
 * Working state shared by all helpers of the point-sprite lowering pass.
 *
 * The whole struct is zero-initialised by d3d12_lower_point_sprite() before
 * use, so every pointer starts out NULL and every flag false.
 */
struct lower_state {
   /* State uniform created by the pass: (1/w, 1/h, pt_sz, max_sz) */
   nir_variable *uniform; /* (1/w, 1/h, pt_sz, max_sz) */
   /* Shader output at VARYING_SLOT_POS, as found by find_outputs() */
   nir_variable *pos_out;
   /* Shader output at VARYING_SLOT_PSIZ, as found by find_outputs() */
   nir_variable *psiz_out;
   /* Outputs created by the pass to carry point coordinates; only the first
    * num_point_coords entries are valid */
   nir_variable *point_coord_out[10];
   unsigned num_point_coords;
   /* All remaining shader outputs, indexed by their data.location */
   nir_variable *varying_out[VARYING_SLOT_MAX];

   /* Per-corner immediates, one entry for each of the four quad vertices:
    * (x, y) offset direction and point coordinate respectively */
   nir_ssa_def *point_dir_imm[4];
   nir_ssa_def *point_coord_imm[4];

   /* Current point primitive */
   /* Values captured from output stores by lower_store() and consumed (then
    * cleared) by lower_emit_vertex() */
   nir_ssa_def *point_pos;
   nir_ssa_def *point_size;
   nir_ssa_def *varying[VARYING_SLOT_MAX];
   /* Write mask of the store that supplied varying[slot] */
   unsigned varying_write_mask[VARYING_SLOT_MAX];

   /* Copies of the corresponding d3d12_lower_point_sprite() arguments */
   bool sprite_origin_lower_left;
   bool point_size_per_vertex;
   /* NOTE(review): not read or written by any code visible in this file */
   bool aa_point;
};
52 
53 static void
find_outputs(nir_shader *shader, struct lower_state *state)54 find_outputs(nir_shader *shader, struct lower_state *state)
55 {
56    nir_foreach_variable_with_modes(var, shader, nir_var_shader_out) {
57       switch (var->data.location) {
58       case VARYING_SLOT_POS:
59          state->pos_out = var;
60          break;
61       case VARYING_SLOT_PSIZ:
62          state->psiz_out = var;
63          break;
64       default:
65          state->varying_out[var->data.location] = var;
66          break;
67       }
68    }
69 }
70 
71 static nir_ssa_def *
get_point_dir(nir_builder *b, struct lower_state *state, unsigned i)72 get_point_dir(nir_builder *b, struct lower_state *state, unsigned i)
73 {
74    if (state->point_dir_imm[0] == NULL) {
75       state->point_dir_imm[0] = nir_imm_vec2(b, -1, -1);
76       state->point_dir_imm[1] = nir_imm_vec2(b, -1, 1);
77       state->point_dir_imm[2] = nir_imm_vec2(b, 1, -1);
78       state->point_dir_imm[3] = nir_imm_vec2(b, 1, 1);
79    }
80 
81    return state->point_dir_imm[i];
82 }
83 
84 static nir_ssa_def *
get_point_coord(nir_builder *b, struct lower_state *state, unsigned i)85 get_point_coord(nir_builder *b, struct lower_state *state, unsigned i)
86 {
87    if (state->point_coord_imm[0] == NULL) {
88       if (state->sprite_origin_lower_left) {
89          state->point_coord_imm[0] = nir_imm_vec4(b, 0, 0, 0, 1);
90          state->point_coord_imm[1] = nir_imm_vec4(b, 0, 1, 0, 1);
91          state->point_coord_imm[2] = nir_imm_vec4(b, 1, 0, 0, 1);
92          state->point_coord_imm[3] = nir_imm_vec4(b, 1, 1, 0, 1);
93       } else {
94          state->point_coord_imm[0] = nir_imm_vec4(b, 0, 1, 0, 1);
95          state->point_coord_imm[1] = nir_imm_vec4(b, 0, 0, 0, 1);
96          state->point_coord_imm[2] = nir_imm_vec4(b, 1, 1, 0, 1);
97          state->point_coord_imm[3] = nir_imm_vec4(b, 1, 0, 0, 1);
98       }
99    }
100 
101    return state->point_coord_imm[i];
102 }
103 
104 /**
105  * scaled_point_size = pointSize * pos.w * ViewportSizeRcp
106  */
107 static void
get_scaled_point_size(nir_builder *b, struct lower_state *state, nir_ssa_def **x, nir_ssa_def **y)108 get_scaled_point_size(nir_builder *b, struct lower_state *state,
109                       nir_ssa_def **x, nir_ssa_def **y)
110 {
111    /* State uniform contains: (1/ViewportWidth, 1/ViewportHeight, PointSize, MaxPointSize) */
112    nir_ssa_def *uniform = nir_load_var(b, state->uniform);
113    nir_ssa_def *point_size = state->point_size;
114 
115    /* clamp point-size to valid range */
116    if (point_size && state->point_size_per_vertex) {
117       point_size = nir_fmax(b, point_size, nir_imm_float(b, 1.0f));
118       point_size = nir_fmin(b, point_size, nir_imm_float(b, D3D12_MAX_POINT_SIZE));
119    } else {
120       /* Use static point size (from uniform) if the shader output was not set */
121       point_size = nir_channel(b, uniform, 2);
122    }
123 
124    point_size = nir_fmul(b, point_size, nir_channel(b, state->point_pos, 3));
125    *x = nir_fmul(b, point_size, nir_channel(b, uniform, 0));
126    *y = nir_fmul(b, point_size, nir_channel(b, uniform, 1));
127 }
128 
129 static bool
lower_store(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)130 lower_store(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
131 {
132    nir_deref_instr *deref = nir_src_as_deref(instr->src[0]);
133    if (nir_deref_mode_is(deref, nir_var_shader_out)) {
134       nir_variable *var = nir_deref_instr_get_variable(deref);
135 
136       switch (var->data.location) {
137       case VARYING_SLOT_POS:
138          state->point_pos = instr->src[1].ssa;
139          break;
140       case VARYING_SLOT_PSIZ:
141          state->point_size = instr->src[1].ssa;
142          break;
143       default:
144          state->varying[var->data.location] = instr->src[1].ssa;
145          state->varying_write_mask[var->data.location] = nir_intrinsic_write_mask(instr);
146          break;
147       }
148 
149       nir_instr_remove(&instr->instr);
150       return true;
151    }
152 
153    return false;
154 }
155 
156 static bool
lower_emit_vertex(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)157 lower_emit_vertex(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
158 {
159    unsigned stream_id = nir_intrinsic_stream_id(instr);
160 
161    nir_ssa_def *point_width, *point_height;
162    get_scaled_point_size(b, state, &point_width, &point_height);
163 
164    nir_instr_remove(&instr->instr);
165    if (stream_id == 0) {
166       for (unsigned i = 0; i < 4; i++) {
167          /* All outputs need to be emitted for each vertex */
168          for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
169             if (state->varying[slot] != NULL) {
170                nir_store_var(b, state->varying_out[slot], state->varying[slot],
171                              state->varying_write_mask[slot]);
172             }
173          }
174 
175          /* pos = scaled_point_size * point_dir + point_pos */
176          nir_ssa_def *point_dir = get_point_dir(b, state, i);
177          nir_ssa_def *pos = nir_vec4(b,
178                                      nir_ffma(b,
179                                               point_width,
180                                               nir_channel(b, point_dir, 0),
181                                               nir_channel(b, state->point_pos, 0)),
182                                      nir_ffma(b,
183                                               point_height,
184                                               nir_channel(b, point_dir, 1),
185                                               nir_channel(b, state->point_pos, 1)),
186                                      nir_channel(b, state->point_pos, 2),
187                                      nir_channel(b, state->point_pos, 3));
188          nir_store_var(b, state->pos_out, pos, 0xf);
189 
190          /* point coord */
191          nir_ssa_def *point_coord = get_point_coord(b, state, i);
192          for (unsigned j = 0; j < state->num_point_coords; ++j) {
193             unsigned num_channels = glsl_get_components(state->point_coord_out[j]->type);
194             unsigned mask = (1 << num_channels) - 1;
195             nir_store_var(b, state->point_coord_out[j], nir_channels(b, point_coord, mask), mask);
196          }
197 
198          /* EmitVertex */
199          nir_emit_vertex(b, .stream_id = stream_id);
200       }
201 
202       /* EndPrimitive */
203       nir_end_primitive(b, .stream_id = stream_id);
204    }
205 
206    /* Reset everything */
207    state->point_pos = NULL;
208    state->point_size = NULL;
209    for (unsigned i = 0; i < VARYING_SLOT_MAX; ++i)
210       state->varying[i] = NULL;
211 
212    return true;
213 }
214 
215 static bool
lower_instr(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)216 lower_instr(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
217 {
218    b->cursor = nir_before_instr(&instr->instr);
219 
220    if (instr->intrinsic == nir_intrinsic_store_deref) {
221       return lower_store(instr, b, state);
222    } else if (instr->intrinsic == nir_intrinsic_emit_vertex) {
223       return lower_emit_vertex(instr, b, state);
224    } else if (instr->intrinsic == nir_intrinsic_end_primitive) {
225       nir_instr_remove(&instr->instr);
226       return true;
227    }
228 
229    return false;
230 }
231 
232 bool
d3d12_lower_point_sprite(nir_shader *shader, bool sprite_origin_lower_left, bool point_size_per_vertex, unsigned point_coord_enable, uint64_t next_inputs_read)233 d3d12_lower_point_sprite(nir_shader *shader,
234                          bool sprite_origin_lower_left,
235                          bool point_size_per_vertex,
236                          unsigned point_coord_enable,
237                          uint64_t next_inputs_read)
238 {
239    const gl_state_index16 tokens[4] = { STATE_INTERNAL_DRIVER,
240                                         D3D12_STATE_VAR_PT_SPRITE };
241    struct lower_state state;
242    bool progress = false;
243 
244    assert(shader->info.gs.output_primitive == GL_POINTS);
245 
246    memset(&state, 0, sizeof(state));
247    find_outputs(shader, &state);
248    state.sprite_origin_lower_left = sprite_origin_lower_left;
249    state.point_size_per_vertex = point_size_per_vertex;
250 
251    /* Create uniform to retrieve inverse of viewport size and point size:
252     * (1/ViewportWidth, 1/ViewportHeight, PointSize, MaxPointSize) */
253    state.uniform = nir_variable_create(shader,
254                                        nir_var_uniform,
255                                        glsl_vec4_type(),
256                                        "d3d12_ViewportSizeRcp");
257    state.uniform->num_state_slots = 1;
258    state.uniform->state_slots = ralloc_array(state.uniform, nir_state_slot, 1);
259    memcpy(state.uniform->state_slots[0].tokens, tokens,
260           sizeof(state.uniform->state_slots[0].tokens));
261    shader->num_uniforms++;
262 
263    /* Create new outputs for point tex coordinates */
264    unsigned count = 0;
265    for (unsigned int sem = 0; sem < ARRAY_SIZE(state.point_coord_out); sem++) {
266       if (point_coord_enable & BITFIELD64_BIT(sem)) {
267          char tmp[100];
268          unsigned location = VARYING_SLOT_TEX0 + sem;
269 
270          snprintf(tmp, ARRAY_SIZE(tmp), "gl_TexCoord%dMESA", count);
271 
272          nir_variable *var = nir_variable_create(shader,
273                                                  nir_var_shader_out,
274                                                  glsl_vec4_type(),
275                                                  tmp);
276          var->data.location = location;
277          state.point_coord_out[count++] = var;
278       }
279    }
280    if (next_inputs_read & VARYING_BIT_PNTC) {
281       nir_variable *pntcoord_var = nir_variable_create(shader,
282                                                        nir_var_shader_out,
283                                                        glsl_vec_type(2),
284                                                        "gl_PointCoordMESA");
285       pntcoord_var->data.location = VARYING_SLOT_PNTC;
286       state.point_coord_out[count++] = pntcoord_var;
287    }
288 
289    state.num_point_coords = count;
290    if (count) {
291       dxil_reassign_driver_locations(shader, nir_var_shader_out,
292                                      next_inputs_read);
293    }
294 
295    nir_foreach_function(function, shader) {
296       if (function->impl) {
297          nir_builder builder;
298          nir_builder_init(&builder, function->impl);
299          nir_foreach_block(block, function->impl) {
300             nir_foreach_instr_safe(instr, block) {
301                if (instr->type == nir_instr_type_intrinsic)
302                   progress |= lower_instr(nir_instr_as_intrinsic(instr),
303                                           &builder,
304                                           &state);
305             }
306          }
307 
308          nir_metadata_preserve(function->impl, nir_metadata_block_index |
309                                                nir_metadata_dominance);
310       }
311    }
312 
313    shader->info.gs.output_primitive = GL_TRIANGLE_STRIP;
314    shader->info.gs.vertices_out = shader->info.gs.vertices_out * 4 /
315       util_bitcount(shader->info.gs.active_stream_mask);
316    shader->info.gs.active_stream_mask = 1;
317 
318    return progress;
319 }
320