xref: /third_party/mesa3d/src/microsoft/clc/clc_nir.c (revision bf215546)
1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "u_math.h"
25#include "nir.h"
26#include "glsl_types.h"
27#include "nir_types.h"
28#include "nir_builder.h"
29
30#include "clc_nir.h"
31#include "clc_compiler.h"
32#include "../compiler/dxil_nir.h"
33
34static bool
35lower_load_base_global_invocation_id(nir_builder *b, nir_intrinsic_instr *intr,
36                                    nir_variable *var)
37{
38   b->cursor = nir_after_instr(&intr->instr);
39
40   nir_ssa_def *offset =
41      build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
42                          nir_imm_int(b,
43                                      offsetof(struct clc_work_properties_data,
44                                               global_offset_x)),
45                          nir_dest_num_components(intr->dest),
46                          nir_dest_bit_size(intr->dest));
47   nir_ssa_def_rewrite_uses(&intr->dest.ssa, offset);
48   nir_instr_remove(&intr->instr);
49   return true;
50}
51
52static bool
53lower_load_work_dim(nir_builder *b, nir_intrinsic_instr *intr,
54                    nir_variable *var)
55{
56   b->cursor = nir_after_instr(&intr->instr);
57
58   nir_ssa_def *dim =
59      build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
60                          nir_imm_int(b,
61                                      offsetof(struct clc_work_properties_data,
62                                               work_dim)),
63                          nir_dest_num_components(intr->dest),
64                          nir_dest_bit_size(intr->dest));
65   nir_ssa_def_rewrite_uses(&intr->dest.ssa, dim);
66   nir_instr_remove(&intr->instr);
67   return true;
68}
69
70static bool
71lower_load_num_workgroups(nir_builder *b, nir_intrinsic_instr *intr,
72                          nir_variable *var)
73{
74   b->cursor = nir_after_instr(&intr->instr);
75
76   nir_ssa_def *count =
77      build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
78                         nir_imm_int(b,
79                                     offsetof(struct clc_work_properties_data,
80                                              group_count_total_x)),
81                         nir_dest_num_components(intr->dest),
82                         nir_dest_bit_size(intr->dest));
83   nir_ssa_def_rewrite_uses(&intr->dest.ssa, count);
84   nir_instr_remove(&intr->instr);
85   return true;
86}
87
88static bool
89lower_load_base_workgroup_id(nir_builder *b, nir_intrinsic_instr *intr,
90                             nir_variable *var)
91{
92   b->cursor = nir_after_instr(&intr->instr);
93
94   nir_ssa_def *offset =
95      build_load_ubo_dxil(b, nir_imm_int(b, var->data.binding),
96                         nir_imm_int(b,
97                                     offsetof(struct clc_work_properties_data,
98                                              group_id_offset_x)),
99                         nir_dest_num_components(intr->dest),
100                         nir_dest_bit_size(intr->dest));
101   nir_ssa_def_rewrite_uses(&intr->dest.ssa, offset);
102   nir_instr_remove(&intr->instr);
103   return true;
104}
105
106bool
107clc_nir_lower_system_values(nir_shader *nir, nir_variable *var)
108{
109   bool progress = false;
110
111   foreach_list_typed(nir_function, func, node, &nir->functions) {
112      if (!func->is_entrypoint)
113         continue;
114      assert(func->impl);
115
116      nir_builder b;
117      nir_builder_init(&b, func->impl);
118
119      nir_foreach_block(block, func->impl) {
120         nir_foreach_instr_safe(instr, block) {
121            if (instr->type != nir_instr_type_intrinsic)
122               continue;
123
124            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
125
126            switch (intr->intrinsic) {
127            case nir_intrinsic_load_base_global_invocation_id:
128               progress |= lower_load_base_global_invocation_id(&b, intr, var);
129               break;
130            case nir_intrinsic_load_work_dim:
131               progress |= lower_load_work_dim(&b, intr, var);
132               break;
133            case nir_intrinsic_load_num_workgroups:
134               lower_load_num_workgroups(&b, intr, var);
135               break;
136            case nir_intrinsic_load_base_workgroup_id:
137               lower_load_base_workgroup_id(&b, intr, var);
138               break;
139            default: break;
140            }
141         }
142      }
143   }
144
145   return progress;
146}
147
148static bool
149lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
150                        nir_variable *var)
151{
152   b->cursor = nir_before_instr(&intr->instr);
153
154   unsigned bit_size = nir_dest_bit_size(intr->dest);
155   enum glsl_base_type base_type;
156
157   switch (bit_size) {
158   case 64:
159      base_type = GLSL_TYPE_UINT64;
160      break;
161   case 32:
162      base_type = GLSL_TYPE_UINT;
163      break;
164    case 16:
165      base_type = GLSL_TYPE_UINT16;
166      break;
167    case 8:
168      base_type = GLSL_TYPE_UINT8;
169      break;
170   }
171
172   const struct glsl_type *type =
173      glsl_vector_type(base_type, nir_dest_num_components(intr->dest));
174   nir_ssa_def *ptr = nir_vec2(b, nir_imm_int(b, var->data.binding),
175                                  nir_u2u(b, intr->src[0].ssa, 32));
176   nir_deref_instr *deref = nir_build_deref_cast(b, ptr, nir_var_mem_ubo, type,
177                                                    bit_size / 8);
178   deref->cast.align_mul = nir_intrinsic_align_mul(intr);
179   deref->cast.align_offset = nir_intrinsic_align_offset(intr);
180
181   nir_ssa_def *result =
182      nir_load_deref(b, deref);
183   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
184   nir_instr_remove(&intr->instr);
185   return true;
186}
187
188bool
189clc_nir_lower_kernel_input_loads(nir_shader *nir, nir_variable *var)
190{
191   bool progress = false;
192
193   foreach_list_typed(nir_function, func, node, &nir->functions) {
194      if (!func->is_entrypoint)
195         continue;
196      assert(func->impl);
197
198      nir_builder b;
199      nir_builder_init(&b, func->impl);
200
201      nir_foreach_block(block, func->impl) {
202         nir_foreach_instr_safe(instr, block) {
203            if (instr->type != nir_instr_type_intrinsic)
204               continue;
205
206            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
207
208            if (intr->intrinsic == nir_intrinsic_load_kernel_input)
209               progress |= lower_load_kernel_input(&b, intr, var);
210         }
211      }
212   }
213
214   return progress;
215}
216
217
218static nir_variable *
219add_printf_var(struct nir_shader *nir, unsigned uav_id)
220{
221   /* This size is arbitrary. Minimum required per spec is 1MB */
222   const unsigned max_printf_size = 1 * 1024 * 1024;
223   const unsigned printf_array_size = max_printf_size / sizeof(unsigned);
224   nir_variable *var =
225      nir_variable_create(nir, nir_var_mem_ssbo,
226                          glsl_array_type(glsl_uint_type(), printf_array_size, sizeof(unsigned)),
227                          "printf");
228   var->data.binding = uav_id;
229   return var;
230}
231
232bool
233clc_lower_printf_base(nir_shader *nir, unsigned uav_id)
234{
235   nir_variable *printf_var = NULL;
236   nir_ssa_def *printf_deref = NULL;
237   nir_foreach_function(func, nir) {
238      nir_builder b;
239      nir_builder_init(&b, func->impl);
240      b.cursor = nir_before_instr(nir_block_first_instr(nir_start_block(func->impl)));
241      bool progress = false;
242
243      nir_foreach_block(block, func->impl) {
244         nir_foreach_instr_safe(instr, block) {
245            if (instr->type != nir_instr_type_intrinsic)
246               continue;
247            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
248            if (intrin->intrinsic != nir_intrinsic_load_printf_buffer_address)
249               continue;
250
251            if (!printf_var) {
252               printf_var = add_printf_var(nir, uav_id);
253               nir_deref_instr *deref = nir_build_deref_var(&b, printf_var);
254               printf_deref = &deref->dest.ssa;
255            }
256            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, printf_deref);
257            progress = true;
258         }
259      }
260
261      if (progress)
262         nir_metadata_preserve(func->impl, nir_metadata_loop_analysis |
263                                           nir_metadata_block_index |
264                                           nir_metadata_dominance);
265      else
266         nir_metadata_preserve(func->impl, nir_metadata_all);
267   }
268
269   return printf_var != NULL;
270}
271
272static nir_variable *
273find_identical_const_sampler(nir_shader *nir, nir_variable *sampler)
274{
275   nir_foreach_variable_with_modes(uniform, nir, nir_var_uniform) {
276      if (!glsl_type_is_sampler(uniform->type) || !uniform->data.sampler.is_inline_sampler)
277         continue;
278      if (uniform->data.sampler.addressing_mode == sampler->data.sampler.addressing_mode &&
279          uniform->data.sampler.normalized_coordinates == sampler->data.sampler.normalized_coordinates &&
280          uniform->data.sampler.filter_mode == sampler->data.sampler.filter_mode)
281         return uniform;
282   }
283   unreachable("Should have at least found the input sampler");
284}
285
286static bool
287clc_nir_dedupe_const_samplers_instr(nir_builder *b,
288                                    nir_instr *instr,
289                                    void *cb_data)
290{
291   nir_shader *nir = cb_data;
292   if (instr->type != nir_instr_type_tex)
293      return false;
294
295   nir_tex_instr *tex = nir_instr_as_tex(instr);
296   int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
297   if (sampler_idx == -1)
298      return false;
299
300   nir_deref_instr *deref = nir_src_as_deref(tex->src[sampler_idx].src);
301   nir_variable *sampler = nir_deref_instr_get_variable(deref);
302   if (!sampler)
303      return false;
304
305   assert(sampler->data.mode == nir_var_uniform);
306
307   if (!sampler->data.sampler.is_inline_sampler)
308      return false;
309
310   nir_variable *replacement = find_identical_const_sampler(nir, sampler);
311   if (replacement == sampler)
312      return false;
313
314   b->cursor = nir_before_instr(&tex->instr);
315   nir_deref_instr *replacement_deref = nir_build_deref_var(b, replacement);
316   nir_instr_rewrite_src(&tex->instr, &tex->src[sampler_idx].src,
317                         nir_src_for_ssa(&replacement_deref->dest.ssa));
318   nir_deref_instr_remove_if_unused(deref);
319
320   return true;
321}
322
323bool
324clc_nir_dedupe_const_samplers(nir_shader *nir)
325{
326   return nir_shader_instructions_pass(nir,
327                                       clc_nir_dedupe_const_samplers_instr,
328                                       nir_metadata_block_index |
329                                       nir_metadata_dominance,
330                                       nir);
331}
332