/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_kernel.h"
#include "brw_nir.h"

#include "compiler/nir/nir_builder.h"
#include "compiler/spirv/nir_spirv.h"
#include "dev/intel_debug.h"
#include "util/u_atomic.h"

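/* Returns the libclc shader, loading it on first use.  The result is cached
 * on the compiler object; if two threads race to load it, the loser frees
 * its copy and uses the winner's.
 */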
static const nir_shader *
load_clc_shader(struct brw_compiler *compiler, struct disk_cache *disk_cache,
                const nir_shader_compiler_options *nir_options,
                const struct spirv_to_nir_options *spirv_options)
{
   if (compiler->clc_shader)
      return compiler->clc_shader;

   nir_shader *nir = nir_load_libclc_shader(64, disk_cache,
                                            spirv_options, nir_options);
   if (nir == NULL)
      return NULL;

   const nir_shader *old_nir =
      p_atomic_cmpxchg(&compiler->clc_shader, NULL, nir);
   if (old_nir == NULL) {
      /* We won the race */
      return nir;
   } else {
      /* Someone else built the shader first */
      ralloc_free(nir);
      return old_nir;
   }
}

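/* Creates an empty nir_function_impl for the given function and points the
 * builder's cursor at the top of its body.
 */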
static void
builder_init_new_impl(nir_builder *b, nir_function *func)
{
   nir_function_impl *impl = nir_function_impl_create(func);
   nir_builder_init(b, impl);
   b->cursor = nir_before_cf_list(&impl->body);
}

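/* Fills in the body of an atomic builtin with a single deref atomic
 * intrinsic.  If the intrinsic returns a value, the first function parameter
 * is a pointer to the return storage; the remaining parameters map directly
 * to the intrinsic sources, with the first source cast to a deref in the
 * given variable mode.
 */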
static void
implement_atomic_builtin(nir_function *func, nir_intrinsic_op op,
                         enum glsl_base_type data_base_type,
                         nir_variable_mode mode)
{
   nir_builder b;
   builder_init_new_impl(&b, func);

   const struct glsl_type *data_type = glsl_scalar_type(data_base_type);

   unsigned p = 0;

   nir_deref_instr *ret = NULL;
   if (nir_intrinsic_infos[op].has_dest) {
      ret = nir_build_deref_cast(&b, nir_load_param(&b, p++),
                                 nir_var_function_temp, data_type, 0);
   }

   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b.shader, op);

   for (unsigned i = 0; i < nir_intrinsic_infos[op].num_srcs; i++) {
      nir_ssa_def *src = nir_load_param(&b, p++);
      if (i == 0) {
         /* The first source is our deref */
         assert(nir_intrinsic_infos[op].src_components[i] == -1);
         src = &nir_build_deref_cast(&b, src, mode, data_type, 0)->dest.ssa;
      }
      atomic->src[i] = nir_src_for_ssa(src);
   }

   if (nir_intrinsic_infos[op].has_dest) {
      nir_ssa_dest_init_for_type(&atomic->instr, &atomic->dest,
                                 data_type, NULL);
   }

   nir_builder_instr_insert(&b, &atomic->instr);

   if (nir_intrinsic_infos[op].has_dest)
      nir_store_deref(&b, ret, &atomic->dest.ssa, ~0);
}

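/* Implements intel_sub_group_ballot() as a single NIR ballot intrinsic.  As
 * with the atomics above, the first parameter is a pointer to the return
 * storage and the second is the ballot condition.
 */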
static void
implement_sub_group_ballot_builtin(nir_function *func)
{
   nir_builder b;
   builder_init_new_impl(&b, func);

   nir_deref_instr *ret =
      nir_build_deref_cast(&b, nir_load_param(&b, 0),
                           nir_var_function_temp, glsl_uint_type(), 0);
   nir_ssa_def *cond = nir_load_param(&b, 1);

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
   ballot->src[0] = nir_src_for_ssa(cond);
   ballot->num_components = 1;
   nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL);
   nir_builder_instr_insert(&b, &ballot->instr);

   nir_store_deref(&b, ret, &ballot->dest.ssa, ~0);
}

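/* Provides NIR implementations for the Intel-specific builtins that libclc
 * does not cover.  The functions are matched by their Itanium-mangled names;
 * for example, _Z10atomic_minPU3AS1Vff is atomic_min with a volatile
 * __global (address space 1) float pointer and a float argument.
 */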
static bool
implement_intel_builtins(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (strcmp(func->name, "_Z10atomic_minPU3AS1Vff") == 0) {
         /* float atomic_min(volatile __global float *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS1Vff") == 0) {
         /* float atomic_max(volatile __global float *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_minPU3AS3Vff") == 0) {
         /* float atomic_min(volatile __local float *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS3Vff") == 0) {
         /* float atomic_max(volatile __local float *p, float val) */
         implement_atomic_builtin(func, nir_intrinsic_deref_atomic_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "intel_sub_group_ballot") == 0) {
         implement_sub_group_ballot_builtin(func);
         progress = true;
      }
   }

   nir_shader_preserve_all_metadata(nir);

   return progress;
}

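/* Rewrites kernel-specific intrinsics in terms of the uniform (push
 * constant) block that the Intel back-end expects.
 */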
static bool
lower_kernel_intrinsics(nir_shader *nir)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   bool progress = false;

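   /* The uniform block is laid out with the system values (see struct
    * brw_kernel_sysvals) at offset 0 and the kernel arguments immediately
    * after them.
    */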
   unsigned kernel_sysvals_start = 0;
   unsigned kernel_arg_start = sizeof(struct brw_kernel_sysvals);
   nir->num_uniforms += kernel_arg_start;

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_kernel_input: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = intrin->num_components;
            load->src[0] = nir_src_for_ssa(nir_u2u32(&b, intrin->src[0].ssa));
            nir_intrinsic_set_base(load, kernel_arg_start);
            nir_intrinsic_set_range(load, nir->num_uniforms);
            nir_ssa_dest_init(&load->instr, &load->dest,
                              intrin->dest.ssa.num_components,
                              intrin->dest.ssa.bit_size, NULL);
            nir_builder_instr_insert(&b, &load->instr);

            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, &load->dest.ssa);
            progress = true;
            break;
         }

         case nir_intrinsic_load_constant_base_ptr: {
            b.cursor = nir_instr_remove(&intrin->instr);
            nir_ssa_def *const_data_base_addr = nir_pack_64_2x32_split(&b,
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_LOW),
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_HIGH));
            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, const_data_base_addr);
            progress = true;
            break;
         }

         case nir_intrinsic_load_num_workgroups: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = 3;
            load->src[0] = nir_src_for_ssa(nir_imm_int(&b, 0));
            nir_intrinsic_set_base(load, kernel_sysvals_start +
               offsetof(struct brw_kernel_sysvals, num_work_groups));
            nir_intrinsic_set_range(load, 3 * 4);
            nir_ssa_dest_init(&load->instr, &load->dest, 3, 32, NULL);
            nir_builder_instr_insert(&b, &load->instr);

            /* We may need to do a bit-size cast here */
            nir_ssa_def *num_work_groups =
               nir_u2u(&b, &load->dest.ssa, intrin->dest.ssa.bit_size);

            nir_ssa_def_rewrite_uses(&intrin->dest.ssa, num_work_groups);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

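/* Compiles an OpenCL kernel from SPIR-V into GPU binary code via
 * brw_compile_cs.  On success, fills out the given brw_kernel and returns
 * true; on failure, returns false and, if error_str is non-NULL, stores an
 * error message in it.  A caller might use it roughly like this (a sketch
 * only; the surrounding setup is the caller's responsibility):
 *
 *    struct brw_kernel kernel = {};
 *    char *error_str = NULL;
 *    if (!brw_kernel_from_spirv(compiler, disk_cache, &kernel, log_data,
 *                               mem_ctx, spirv_words, spirv_size_in_bytes,
 *                               "my_kernel", &error_str))
 *       fprintf(stderr, "compile failed: %s\n", error_str);
 */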
bool
brw_kernel_from_spirv(struct brw_compiler *compiler,
                      struct disk_cache *disk_cache,
                      struct brw_kernel *kernel,
                      void *log_data, void *mem_ctx,
                      const uint32_t *spirv, size_t spirv_size,
                      const char *entrypoint_name,
                      char **error_str)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->nir_options[MESA_SHADER_KERNEL];

   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .caps = {
         .address = true,
         .float16 = devinfo->ver >= 8,
         .float64 = devinfo->ver >= 8,
         .groups = true,
         .image_write_without_format = true,
         .int8 = devinfo->ver >= 8,
         .int16 = devinfo->ver >= 8,
         .int64 = devinfo->ver >= 8,
         .int64_atomics = devinfo->ver >= 9,
         .kernel = true,
         .linkage = true, /* We receive a linked kernel from clc */
         .float_controls = devinfo->ver >= 8,
         .generic_pointers = true,
         .storage_8bit = devinfo->ver >= 8,
         .storage_16bit = devinfo->ver >= 8,
         .subgroup_arithmetic = true,
         .subgroup_basic = true,
         .subgroup_ballot = true,
         .subgroup_dispatch = true,
         .subgroup_quad = true,
         .subgroup_shuffle = true,
         .subgroup_vote = true,

         .intel_subgroup_shuffle = true,
         .intel_subgroup_buffer_block_io = true,
      },
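      /* OpenCL generic pointers may point at global, shared, or scratch
       * memory, so those modes all go through the generic 62-bit format and
       * the actual mode is resolved at runtime.  Constant data lives at a
       * known 64-bit global address (see load_constant_base_ptr below).
       */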
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
   };

   spirv_options.clc_shader = load_clc_shader(compiler, disk_cache,
                                              nir_options, &spirv_options);

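   /* SPIR-V is a stream of 32-bit words */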
   assert(spirv_size % 4 == 0);
   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   entrypoint_name, &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, entrypoint_name);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function(function, nir) {
         if (function->impl)
            nir_index_ssa_defs(function->impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

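   /* Fill in the Intel-specific builtins, then link in the libclc library
    * functions for everything else.
    */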
   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_lower_libclc, spirv_options.clc_shader);

   /* We have to lower away local constant initializers right before we
    * inline functions.  That way they get properly initialized at the top
    * of the function and not at the top of its caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);
   NIR_PASS_V(nir, nir_copy_prop);
   NIR_PASS_V(nir, nir_opt_deref);

   /* Pick off the single entrypoint that we want */
   nir_remove_non_entrypoints(nir);

   /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers.  We do this here so that
    * nir_remove_dead_variables and split_per_member_structs below see the
    * corresponding stores.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~0);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are 16B
    * aligned and so it can just read/write them as vec4s.  This results in a
    * LOT of vec4->vec3 casts on loads and stores.  One solution to this
    * problem is to get rid of all vec3 variables.
    */
   NIR_PASS_V(nir, nir_lower_vec3_to_vec4,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global |
              nir_var_mem_constant);

   /* We assign explicit types early so that the optimizer can take advantage
    * of that information and hopefully get rid of some of our memcpys.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_uniform |
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              glsl_get_cl_type_size_align);

   brw_preprocess_nir(compiler, nir, NULL);

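   /* Each kernel argument shows up as a uniform variable: data.location is
    * the argument index and data.driver_location is its byte offset in the
    * argument buffer.
    */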
   int max_arg_idx = -1;
   nir_foreach_uniform_variable(var, nir) {
      assert(var->data.location < 256);
      max_arg_idx = MAX2(max_arg_idx, var->data.location);
   }

   kernel->args_size = nir->num_uniforms;
   kernel->arg_count = max_arg_idx + 1;

   /* No bindings */
   struct brw_kernel_arg_desc *args =
      rzalloc_array(mem_ctx, struct brw_kernel_arg_desc, kernel->arg_count);
   kernel->args = args;

   nir_foreach_uniform_variable(var, nir) {
      struct brw_kernel_arg_desc arg_desc = {
         .offset = var->data.driver_location,
         .size = glsl_get_explicit_size(var->type, false),
      };
      assert(arg_desc.offset + arg_desc.size <= nir->num_uniforms);

      assert(var->data.location >= 0);
      args[var->data.location] = arg_desc;
   }

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, NULL);

   /* Lower again, this time after dead-variable removal, to get more
    * compact variable layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function(function, nir) {
         if (function->impl)
            nir_index_ssa_defs(function->impl);
      }

      fprintf(stderr, "NIR (before I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

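   /* Lower derefs to explicit address arithmetic: constants as 64-bit
    * global addresses, uniforms as 32-bit offsets stored in 64-bit values,
    * and everything else through the generic format chosen above.
    */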
   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   NIR_PASS_V(nir, nir_lower_frexp);
   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);
   NIR_PASS_V(nir, lower_kernel_intrinsics);

   struct brw_cs_prog_key key = { };

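   /* nr_params counts 32-bit push constant slots, so convert the uniform
    * size from bytes to dwords.
    */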
   memset(&kernel->prog_data, 0, sizeof(kernel->prog_data));
   kernel->prog_data.base.nr_params = DIV_ROUND_UP(nir->num_uniforms, 4);

   struct brw_compile_cs_params params = {
      .nir = nir,
      .key = &key,
      .prog_data = &kernel->prog_data,
      .stats = kernel->stats,
      .log_data = log_data,
   };

   kernel->code = brw_compile_cs(compiler, mem_ctx, &params);

   if (error_str)
      *error_str = params.error_str;

   return kernel->code != NULL;
}