1//
2// Copyright 2012-2016 Francisco Jerez
3// Copyright 2012-2016 Advanced Micro Devices, Inc.
4// Copyright 2014-2016 Jan Vesely
5// Copyright 2014-2015 Serge Martin
6// Copyright 2015 Zoltan Gilian
7//
8// Permission is hereby granted, free of charge, to any person obtaining a
9// copy of this software and associated documentation files (the "Software"),
10// to deal in the Software without restriction, including without limitation
11// the rights to use, copy, modify, merge, publish, distribute, sublicense,
12// and/or sell copies of the Software, and to permit persons to whom the
13// Software is furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24// OTHER DEALINGS IN THE SOFTWARE.
25
26#include <sstream>
27#include <mutex>
28
29#include <llvm/ADT/ArrayRef.h>
30#include <llvm/IR/DiagnosticPrinter.h>
31#include <llvm/IR/DiagnosticInfo.h>
32#include <llvm/IR/LLVMContext.h>
33#include <llvm/IR/Type.h>
34#include <llvm/Support/raw_ostream.h>
35#include <llvm/Bitcode/BitcodeWriter.h>
36#include <llvm/Bitcode/BitcodeReader.h>
37#include <llvm-c/Core.h>
38#include <llvm-c/Target.h>
39#include <LLVMSPIRVLib/LLVMSPIRVLib.h>
40
41#include <clang/CodeGen/CodeGenAction.h>
42#include <clang/Lex/PreprocessorOptions.h>
43#include <clang/Frontend/CompilerInstance.h>
44#include <clang/Frontend/TextDiagnosticBuffer.h>
45#include <clang/Frontend/TextDiagnosticPrinter.h>
46#include <clang/Basic/TargetInfo.h>
47
48#include <spirv-tools/libspirv.hpp>
49#include <spirv-tools/linker.hpp>
50#include <spirv-tools/optimizer.hpp>
51
52#include "util/macros.h"
53#include "glsl_types.h"
54
55#include "spirv.h"
56
57#ifdef USE_STATIC_OPENCL_C_H
58#if LLVM_VERSION_MAJOR < 15
59#include "opencl-c.h.h"
60#endif
61#include "opencl-c-base.h.h"
62#endif
63
64#include "clc_helpers.h"
65
66/* Use the highest version of SPIRV supported by SPIRV-Tools. */
67constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
68
69constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
70
71using ::llvm::Function;
72using ::llvm::LLVMContext;
73using ::llvm::Module;
74using ::llvm::raw_string_ostream;
75
76static void
77llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
78   const clc_logger *logger = static_cast<clc_logger *>(data);
79
80   std::string log;
81   raw_string_ostream os { log };
82   ::llvm::DiagnosticPrinterRawOStream printer { os };
83   di.print(printer);
84
85   clc_error(logger, "%s", log.c_str());
86}
87
88class SPIRVKernelArg {
89public:
90   SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
91                                                  addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
92                                                  accessQualifier(0),
93                                                  typeQualifier(0) { }
94   ~SPIRVKernelArg() { }
95
96   uint32_t id;
97   uint32_t typeId;
98   std::string name;
99   std::string typeName;
100   enum clc_kernel_arg_address_qualifier addrQualifier;
101   unsigned accessQualifier;
102   unsigned typeQualifier;
103};
104
105class SPIRVKernelInfo {
106public:
107   SPIRVKernelInfo(uint32_t fid, const char *nm)
108      : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
109   ~SPIRVKernelInfo() { }
110
111   uint32_t funcId;
112   std::string name;
113   std::vector<SPIRVKernelArg> args;
114   unsigned vecHint;
115   unsigned localSize[3];
116   unsigned localSizeHint[3];
117};
118
119class SPIRVKernelParser {
120public:
121   SPIRVKernelParser() : curKernel(NULL)
122   {
123      ctx = spvContextCreate(spirv_target);
124   }
125
126   ~SPIRVKernelParser()
127   {
128     spvContextDestroy(ctx);
129   }
130
131   void parseEntryPoint(const spv_parsed_instruction_t *ins)
132   {
133      assert(ins->num_operands >= 3);
134
135      const spv_parsed_operand_t *op = &ins->operands[1];
136
137      assert(op->type == SPV_OPERAND_TYPE_ID);
138
139      uint32_t funcId = ins->words[op->offset];
140
141      for (auto &iter : kernels) {
142         if (funcId == iter.funcId)
143            return;
144      }
145
146      op = &ins->operands[2];
147      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
148      const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
149
150      kernels.push_back(SPIRVKernelInfo(funcId, name));
151   }
152
153   void parseFunction(const spv_parsed_instruction_t *ins)
154   {
155      assert(ins->num_operands == 4);
156
157      const spv_parsed_operand_t *op = &ins->operands[1];
158
159      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
160
161      uint32_t funcId = ins->words[op->offset];
162
163      for (auto &kernel : kernels) {
164         if (funcId == kernel.funcId && !kernel.args.size()) {
165            curKernel = &kernel;
166	    return;
167         }
168      }
169   }
170
171   void parseFunctionParam(const spv_parsed_instruction_t *ins)
172   {
173      const spv_parsed_operand_t *op;
174      uint32_t id, typeId;
175
176      if (!curKernel)
177         return;
178
179      assert(ins->num_operands == 2);
180      op = &ins->operands[0];
181      assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
182      typeId = ins->words[op->offset];
183      op = &ins->operands[1];
184      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
185      id = ins->words[op->offset];
186      curKernel->args.push_back(SPIRVKernelArg(id, typeId));
187   }
188
189   void parseName(const spv_parsed_instruction_t *ins)
190   {
191      const spv_parsed_operand_t *op;
192      const char *name;
193      uint32_t id;
194
195      assert(ins->num_operands == 2);
196
197      op = &ins->operands[0];
198      assert(op->type == SPV_OPERAND_TYPE_ID);
199      id = ins->words[op->offset];
200      op = &ins->operands[1];
201      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
202      name = reinterpret_cast<const char *>(ins->words + op->offset);
203
204      for (auto &kernel : kernels) {
205         for (auto &arg : kernel.args) {
206            if (arg.id == id && arg.name.empty()) {
207              arg.name = name;
208              break;
209	    }
210         }
211      }
212   }
213
214   void parseTypePointer(const spv_parsed_instruction_t *ins)
215   {
216      enum clc_kernel_arg_address_qualifier addrQualifier;
217      uint32_t typeId, storageClass;
218      const spv_parsed_operand_t *op;
219
220      assert(ins->num_operands == 3);
221
222      op = &ins->operands[0];
223      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
224      typeId = ins->words[op->offset];
225
226      op = &ins->operands[1];
227      assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
228      storageClass = ins->words[op->offset];
229      switch (storageClass) {
230      case SpvStorageClassCrossWorkgroup:
231         addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
232         break;
233      case SpvStorageClassWorkgroup:
234         addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
235         break;
236      case SpvStorageClassUniformConstant:
237         addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
238         break;
239      default:
240         addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
241         break;
242      }
243
244      for (auto &kernel : kernels) {
245	 for (auto &arg : kernel.args) {
246            if (arg.typeId == typeId) {
247               arg.addrQualifier = addrQualifier;
248               if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
249                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
250            }
251         }
252      }
253   }
254
255   void parseOpString(const spv_parsed_instruction_t *ins)
256   {
257      const spv_parsed_operand_t *op;
258      std::string str;
259
260      assert(ins->num_operands == 2);
261
262      op = &ins->operands[1];
263      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
264      str = reinterpret_cast<const char *>(ins->words + op->offset);
265
266      size_t start = 0;
267      enum class string_type {
268         arg_type,
269         arg_type_qual,
270      } str_type;
271
272      if (str.find("kernel_arg_type.") == 0) {
273         start = sizeof("kernel_arg_type.") - 1;
274         str_type = string_type::arg_type;
275      } else if (str.find("kernel_arg_type_qual.") == 0) {
276         start = sizeof("kernel_arg_type_qual.") - 1;
277         str_type = string_type::arg_type_qual;
278      } else {
279         return;
280      }
281
282      for (auto &kernel : kernels) {
283         size_t pos;
284
285	 pos = str.find(kernel.name, start);
286         if (pos == std::string::npos ||
287             pos != start || str[start + kernel.name.size()] != '.')
288            continue;
289
290	 pos = start + kernel.name.size();
291         if (str[pos++] != '.')
292            continue;
293
294         for (auto &arg : kernel.args) {
295            if (arg.name.empty())
296               break;
297
298            size_t entryEnd = str.find(',', pos);
299	    if (entryEnd == std::string::npos)
300               break;
301
302            std::string entryVal = str.substr(pos, entryEnd - pos);
303            pos = entryEnd + 1;
304
305            if (str_type == string_type::arg_type) {
306               arg.typeName = std::move(entryVal);
307            } else if (str_type == string_type::arg_type_qual) {
308               if (entryVal.find("const") != std::string::npos)
309                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
310            }
311         }
312      }
313   }
314
315   void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
316   {
317      auto iter = decorationGroups.find(id);
318      if (iter != decorationGroups.end()) {
319         for (uint32_t entry : iter->second)
320            applyDecoration(entry, ins);
321         return;
322      }
323
324      const spv_parsed_operand_t *op;
325      uint32_t decoration;
326
327      assert(ins->num_operands >= 2);
328
329      op = &ins->operands[1];
330      assert(op->type == SPV_OPERAND_TYPE_DECORATION);
331      decoration = ins->words[op->offset];
332
333      if (decoration == SpvDecorationSpecId) {
334         uint32_t spec_id = ins->words[ins->operands[2].offset];
335         for (auto &c : specConstants) {
336            if (c.second.id == spec_id) {
337               assert(c.first == id);
338               return;
339            }
340         }
341         specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
342         return;
343      }
344
345      for (auto &kernel : kernels) {
346         for (auto &arg : kernel.args) {
347            if (arg.id == id) {
348               switch (decoration) {
349               case SpvDecorationVolatile:
350                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
351                  break;
352               case SpvDecorationConstant:
353                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
354                  break;
355               case SpvDecorationRestrict:
356                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
357                  break;
358               case SpvDecorationFuncParamAttr:
359                  op = &ins->operands[2];
360                  assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
361                  switch (ins->words[op->offset]) {
362                  case SpvFunctionParameterAttributeNoAlias:
363                     arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
364                     break;
365                  case SpvFunctionParameterAttributeNoWrite:
366                     arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
367                     break;
368                  }
369                  break;
370               }
371            }
372
373         }
374      }
375   }
376
377   void parseOpDecorate(const spv_parsed_instruction_t *ins)
378   {
379      const spv_parsed_operand_t *op;
380      uint32_t id;
381
382      assert(ins->num_operands >= 2);
383
384      op = &ins->operands[0];
385      assert(op->type == SPV_OPERAND_TYPE_ID);
386      id = ins->words[op->offset];
387
388      applyDecoration(id, ins);
389   }
390
391   void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
392   {
393      assert(ins->num_operands >= 2);
394
395      const spv_parsed_operand_t *op = &ins->operands[0];
396      assert(op->type == SPV_OPERAND_TYPE_ID);
397      uint32_t groupId = ins->words[op->offset];
398
399      auto lowerBound = decorationGroups.lower_bound(groupId);
400      if (lowerBound != decorationGroups.end() &&
401          lowerBound->first == groupId)
402         // Group already filled out
403         return;
404
405      auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
406      auto& vec = iter->second;
407      vec.reserve(ins->num_operands - 1);
408      for (uint32_t i = 1; i < ins->num_operands; ++i) {
409         op = &ins->operands[i];
410         assert(op->type == SPV_OPERAND_TYPE_ID);
411         vec.push_back(ins->words[op->offset]);
412      }
413   }
414
415   void parseOpTypeImage(const spv_parsed_instruction_t *ins)
416   {
417      const spv_parsed_operand_t *op;
418      uint32_t typeId;
419      unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
420
421      op = &ins->operands[0];
422      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
423      typeId = ins->words[op->offset];
424
425      if (ins->num_operands >= 9) {
426         op = &ins->operands[8];
427         assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
428         switch (ins->words[op->offset]) {
429         case SpvAccessQualifierReadOnly:
430            accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
431            break;
432         case SpvAccessQualifierWriteOnly:
433            accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
434            break;
435         case SpvAccessQualifierReadWrite:
436            accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
437               CLC_KERNEL_ARG_ACCESS_READ;
438            break;
439         }
440      }
441
442      for (auto &kernel : kernels) {
443	 for (auto &arg : kernel.args) {
444            if (arg.typeId == typeId) {
445               arg.accessQualifier = accessQualifier;
446               arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
447            }
448         }
449      }
450   }
451
452   void parseExecutionMode(const spv_parsed_instruction_t *ins)
453   {
454      uint32_t executionMode = ins->words[ins->operands[1].offset];
455      uint32_t funcId = ins->words[ins->operands[0].offset];
456
457      for (auto& kernel : kernels) {
458         if (kernel.funcId == funcId) {
459            switch (executionMode) {
460            case SpvExecutionModeVecTypeHint:
461               kernel.vecHint = ins->words[ins->operands[2].offset];
462               break;
463            case SpvExecutionModeLocalSize:
464               kernel.localSize[0] = ins->words[ins->operands[2].offset];
465               kernel.localSize[1] = ins->words[ins->operands[3].offset];
466               kernel.localSize[2] = ins->words[ins->operands[4].offset];
467            case SpvExecutionModeLocalSizeHint:
468               kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
469               kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
470               kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
471            default:
472               return;
473            }
474         }
475      }
476   }
477
478   void parseLiteralType(const spv_parsed_instruction_t *ins)
479   {
480      uint32_t typeId = ins->words[ins->operands[0].offset];
481      auto& literalType = literalTypes[typeId];
482      switch (ins->opcode) {
483      case SpvOpTypeBool:
484         literalType = CLC_SPEC_CONSTANT_BOOL;
485         break;
486      case SpvOpTypeFloat: {
487         uint32_t sizeInBits = ins->words[ins->operands[1].offset];
488         switch (sizeInBits) {
489         case 32:
490            literalType = CLC_SPEC_CONSTANT_FLOAT;
491            break;
492         case 64:
493            literalType = CLC_SPEC_CONSTANT_DOUBLE;
494            break;
495         case 16:
496            /* Can't be used for a spec constant */
497            break;
498         default:
499            unreachable("Unexpected float bit size");
500         }
501         break;
502      }
503      case SpvOpTypeInt: {
504         uint32_t sizeInBits = ins->words[ins->operands[1].offset];
505         bool isSigned = ins->words[ins->operands[2].offset];
506         if (isSigned) {
507            switch (sizeInBits) {
508            case 8:
509               literalType = CLC_SPEC_CONSTANT_INT8;
510               break;
511            case 16:
512               literalType = CLC_SPEC_CONSTANT_INT16;
513               break;
514            case 32:
515               literalType = CLC_SPEC_CONSTANT_INT32;
516               break;
517            case 64:
518               literalType = CLC_SPEC_CONSTANT_INT64;
519               break;
520            default:
521               unreachable("Unexpected int bit size");
522            }
523         } else {
524            switch (sizeInBits) {
525            case 8:
526               literalType = CLC_SPEC_CONSTANT_UINT8;
527               break;
528            case 16:
529               literalType = CLC_SPEC_CONSTANT_UINT16;
530               break;
531            case 32:
532               literalType = CLC_SPEC_CONSTANT_UINT32;
533               break;
534            case 64:
535               literalType = CLC_SPEC_CONSTANT_UINT64;
536               break;
537            default:
538               unreachable("Unexpected uint bit size");
539            }
540         }
541         break;
542      }
543      default:
544         unreachable("Unexpected type opcode");
545      }
546   }
547
548   void parseSpecConstant(const spv_parsed_instruction_t *ins)
549   {
550      uint32_t id = ins->result_id;
551      for (auto& c : specConstants) {
552         if (c.first == id) {
553            auto& data = c.second;
554            switch (ins->opcode) {
555            case SpvOpSpecConstant: {
556               uint32_t typeId = ins->words[ins->operands[0].offset];
557
558               // This better be an integer or float type
559               auto typeIter = literalTypes.find(typeId);
560               assert(typeIter != literalTypes.end());
561
562               data.type = typeIter->second;
563               break;
564            }
565            case SpvOpSpecConstantFalse:
566            case SpvOpSpecConstantTrue:
567               data.type = CLC_SPEC_CONSTANT_BOOL;
568               break;
569            default:
570               unreachable("Composites and Ops are not directly specializable.");
571            }
572         }
573      }
574   }
575
576   static spv_result_t
577   parseInstruction(void *data, const spv_parsed_instruction_t *ins)
578   {
579      SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
580
581      switch (ins->opcode) {
582      case SpvOpName:
583         parser->parseName(ins);
584         break;
585      case SpvOpEntryPoint:
586         parser->parseEntryPoint(ins);
587         break;
588      case SpvOpFunction:
589         parser->parseFunction(ins);
590         break;
591      case SpvOpFunctionParameter:
592         parser->parseFunctionParam(ins);
593         break;
594      case SpvOpFunctionEnd:
595      case SpvOpLabel:
596         parser->curKernel = NULL;
597         break;
598      case SpvOpTypePointer:
599         parser->parseTypePointer(ins);
600         break;
601      case SpvOpTypeImage:
602         parser->parseOpTypeImage(ins);
603         break;
604      case SpvOpString:
605         parser->parseOpString(ins);
606         break;
607      case SpvOpDecorate:
608         parser->parseOpDecorate(ins);
609         break;
610      case SpvOpGroupDecorate:
611         parser->parseOpGroupDecorate(ins);
612         break;
613      case SpvOpExecutionMode:
614         parser->parseExecutionMode(ins);
615         break;
616      case SpvOpTypeBool:
617      case SpvOpTypeInt:
618      case SpvOpTypeFloat:
619         parser->parseLiteralType(ins);
620         break;
621      case SpvOpSpecConstant:
622      case SpvOpSpecConstantFalse:
623      case SpvOpSpecConstantTrue:
624         parser->parseSpecConstant(ins);
625         break;
626      default:
627         break;
628      }
629
630      return SPV_SUCCESS;
631   }
632
633   bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
634   {
635      /* 3 passes should be enough to retrieve all kernel information:
636       * 1st pass: all entry point name and number of args
637       * 2nd pass: argument names and type names
638       * 3rd pass: pointer type names
639       */
640      for (unsigned pass = 0; pass < 3; pass++) {
641         spv_diagnostic diagnostic = NULL;
642         auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
643                                      static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
644                                      NULL, parseInstruction, &diagnostic);
645
646         if (result != SPV_SUCCESS) {
647            if (diagnostic && logger)
648               logger->error(logger->priv, diagnostic->error);
649            return false;
650         }
651      }
652
653      return true;
654   }
655
656   std::vector<SPIRVKernelInfo> kernels;
657   std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
658   std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
659   std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
660   SPIRVKernelInfo *curKernel;
661   spv_context ctx;
662};
663
664bool
665clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
666                           const struct clc_kernel_info **out_kernels,
667                           unsigned *num_kernels,
668                           const struct clc_parsed_spec_constant **out_spec_constants,
669                           unsigned *num_spec_constants,
670                           const struct clc_logger *logger)
671{
672   struct clc_kernel_info *kernels;
673   struct clc_parsed_spec_constant *spec_constants = NULL;
674
675   SPIRVKernelParser parser;
676
677   if (!parser.parseBinary(*spvbin, logger))
678      return false;
679
680   *num_kernels = parser.kernels.size();
681   *num_spec_constants = parser.specConstants.size();
682   if (!*num_kernels)
683      return false;
684
685   kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
686                                                               sizeof(*kernels)));
687   assert(kernels);
688   for (unsigned i = 0; i < parser.kernels.size(); i++) {
689      kernels[i].name = strdup(parser.kernels[i].name.c_str());
690      kernels[i].num_args = parser.kernels[i].args.size();
691      kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
692      kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
693      memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
694      memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
695      if (!kernels[i].num_args)
696         continue;
697
698      struct clc_kernel_arg *args;
699
700      args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
701                                                       sizeof(*kernels->args)));
702      kernels[i].args = args;
703      assert(args);
704      for (unsigned j = 0; j < kernels[i].num_args; j++) {
705         if (!parser.kernels[i].args[j].name.empty())
706            args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
707         args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
708         args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
709         args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
710         args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
711      }
712   }
713
714   if (*num_spec_constants) {
715      spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
716                                                                                  sizeof(*spec_constants)));
717      assert(spec_constants);
718
719      for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
720         spec_constants[i] = parser.specConstants[i].second;
721      }
722   }
723
724   *out_kernels = kernels;
725   *out_spec_constants = spec_constants;
726
727   return true;
728}
729
730void
731clc_free_kernels_info(const struct clc_kernel_info *kernels,
732                      unsigned num_kernels)
733{
734   if (!kernels)
735      return;
736
737   for (unsigned i = 0; i < num_kernels; i++) {
738      if (kernels[i].args) {
739         for (unsigned j = 0; j < kernels[i].num_args; j++) {
740            free((void *)kernels[i].args[j].name);
741            free((void *)kernels[i].args[j].type_name);
742         }
743      }
744      free((void *)kernels[i].name);
745   }
746
747   free((void *)kernels);
748}
749
750static std::unique_ptr<::llvm::Module>
751clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
752                           const struct clc_compile_args *args,
753                           const struct clc_logger *logger)
754{
755   std::string diag_log_str;
756   raw_string_ostream diag_log_stream { diag_log_str };
757
758   std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
759
760   clang::DiagnosticsEngine diag {
761      new clang::DiagnosticIDs,
762      new clang::DiagnosticOptions,
763      new clang::TextDiagnosticPrinter(diag_log_stream,
764                                       &c->getDiagnosticOpts())
765   };
766
767   std::vector<const char *> clang_opts = {
768      args->source.name,
769      "-triple", "spir64-unknown-unknown",
770      // By default, clang prefers to use modules to pull in the default headers,
771      // which doesn't work with our technique of embedding the headers in our binary
772#if LLVM_VERSION_MAJOR >= 15
773      "-fdeclare-opencl-builtins",
774#else
775      "-finclude-default-header",
776#endif
777#if LLVM_VERSION_MAJOR >= 15
778      "-no-opaque-pointers",
779#endif
780      // Add a default CL compiler version. Clang will pick the last one specified
781      // on the command line, so the app can override this one.
782      "-cl-std=cl1.2",
783      // The LLVM-SPIRV-Translator doesn't support memset with variable size
784      "-fno-builtin-memset",
785      // LLVM's optimizations can produce code that the translator can't translate
786      "-O0",
787      // Ensure inline functions are actually emitted
788      "-fgnu89-inline"
789   };
790   // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
791   // being provided by the caller here.
792   clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
793
794   if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
795#if LLVM_VERSION_MAJOR >= 10
796                                                  clang_opts,
797#else
798                                                  clang_opts.data(),
799                                                  clang_opts.data() + clang_opts.size(),
800#endif
801                                                  diag)) {
802      clc_error(logger, "Couldn't create Clang invocation.\n");
803      return {};
804   }
805
806   if (diag.hasErrorOccurred()) {
807      clc_error(logger, "%sErrors occurred during Clang invocation.\n",
808                diag_log_str.c_str());
809      return {};
810   }
811
812   // This is a workaround for a Clang bug which causes the number
813   // of warnings and errors to be printed to stderr.
814   // http://www.llvm.org/bugs/show_bug.cgi?id=19735
815   c->getDiagnosticOpts().ShowCarets = false;
816
817   c->createDiagnostics(new clang::TextDiagnosticPrinter(
818                           diag_log_stream,
819                           &c->getDiagnosticOpts()));
820
821   c->setTarget(clang::TargetInfo::CreateTargetInfo(
822                   c->getDiagnostics(), c->getInvocation().TargetOpts));
823
824   c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
825
826#ifdef USE_STATIC_OPENCL_C_H
827   c->getHeaderSearchOpts().UseBuiltinIncludes = false;
828   c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
829
830   // Add opencl-c generic search path
831   {
832      ::llvm::SmallString<128> system_header_path;
833      ::llvm::sys::path::system_temp_directory(true, system_header_path);
834      ::llvm::sys::path::append(system_header_path, "openclon12");
835      c->getHeaderSearchOpts().AddPath(system_header_path.str(),
836                                       clang::frontend::Angled,
837                                       false, false);
838
839#if LLVM_VERSION_MAJOR < 15
840      ::llvm::sys::path::append(system_header_path, "opencl-c.h");
841      c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
842         ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
843#endif
844
845      ::llvm::sys::path::remove_filename(system_header_path);
846      ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
847      c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
848         ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
849   }
850#else
851   c->getHeaderSearchOpts().UseBuiltinIncludes = true;
852   c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
853   c->getHeaderSearchOpts().ResourceDir = CLANG_RESOURCE_DIR;
854
855   // Add opencl-c generic search path
856   c->getHeaderSearchOpts().AddPath(CLANG_RESOURCE_DIR,
857                                    clang::frontend::Angled,
858                                    false, false);
859   // Add opencl include
860#if LLVM_VERSION_MAJOR >= 15
861   c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
862#else
863   c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
864#endif
865#endif
866
867#if LLVM_VERSION_MAJOR >= 14
868   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
869   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
870   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
871   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
872   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
873   c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
874   if (args->features.fp16) {
875      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
876   }
877   if (args->features.fp64) {
878      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
879      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
880   }
881   if (args->features.int64) {
882      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
883      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
884   }
885   if (args->features.images) {
886      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
887   }
888   if (args->features.images_read_write) {
889      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
890   }
891   if (args->features.images_write_3d) {
892      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
893      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
894   }
895   if (args->features.intel_subgroups) {
896      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
897   }
898   if (args->features.subgroups) {
899      c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
900   }
901#endif
902
903   if (args->num_headers) {
904      ::llvm::SmallString<128> tmp_header_path;
905      ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
906      ::llvm::sys::path::append(tmp_header_path, "openclon12");
907
908      c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
909                                       clang::frontend::Quoted,
910                                       false, false);
911
912      for (size_t i = 0; i < args->num_headers; i++) {
913         auto path_copy = tmp_header_path;
914         ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
915         c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
916            ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
917      }
918   }
919
920   c->getPreprocessorOpts().addRemappedFile(
921           args->source.name,
922           ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
923
924   // Compile the code
925   clang::EmitLLVMOnlyAction act(&llvm_ctx);
926   if (!c->ExecuteAction(act)) {
927      clc_error(logger, "%sError executing LLVM compilation action.\n",
928                diag_log_str.c_str());
929      return {};
930   }
931
932   return act.takeModule();
933}
934
935static SPIRV::VersionNumber
936spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
937{
938   switch (version) {
939   case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
940   case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
941   case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
942   case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
943   case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
944#ifdef HAS_SPIRV_1_4
945   case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
946#endif
947   default:      return invalid_spirv_trans_version;
948   }
949}
950
951static int
952llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
953                  LLVMContext &context,
954                  const struct clc_compile_args *args,
955                  const struct clc_logger *logger,
956                  struct clc_binary *out_spirv)
957{
958   std::string log;
959
960   SPIRV::VersionNumber version =
961      spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
962   if (version == invalid_spirv_trans_version) {
963      clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
964      return -1;
965   }
966
967   const char *const *extensions = NULL;
968   if (args)
969      extensions = args->allowed_spirv_extensions;
970   if (!extensions) {
971      /* The SPIR-V parser doesn't handle all extensions */
972      static const char *default_extensions[] = {
973         "SPV_EXT_shader_atomic_float_add",
974         "SPV_EXT_shader_atomic_float_min_max",
975         "SPV_KHR_float_controls",
976         NULL,
977      };
978      extensions = default_extensions;
979   }
980
981   SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
982   for (int i = 0; extensions[i]; i++) {
983#define EXT(X) \
984      if (strcmp(#X, extensions[i]) == 0) \
985         ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
986#include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
987#undef EXT
988   }
989   SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
990
991#if LLVM_VERSION_MAJOR >= 13
992   /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
993   spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
994#endif
995
996   std::ostringstream spv_stream;
997   if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
998      clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
999                log.c_str());
1000      return -1;
1001   }
1002
1003   const std::string spv_out = spv_stream.str();
1004   out_spirv->size = spv_out.size();
1005   out_spirv->data = malloc(out_spirv->size);
1006   memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1007
1008   return 0;
1009}
1010
1011int
1012clc_c_to_spir(const struct clc_compile_args *args,
1013              const struct clc_logger *logger,
1014              struct clc_binary *out_spir)
1015{
1016   clc_initialize_llvm();
1017
1018   LLVMContext llvm_ctx;
1019   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1020                                         const_cast<clc_logger *>(logger));
1021
1022   auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1023   if (!mod)
1024      return -1;
1025
1026   ::llvm::SmallVector<char, 0> buffer;
1027   ::llvm::BitcodeWriter writer(buffer);
1028   writer.writeModule(*mod);
1029
1030   out_spir->size = buffer.size_in_bytes();
1031   out_spir->data = malloc(out_spir->size);
1032   memcpy(out_spir->data, buffer.data(), out_spir->size);
1033
1034   return 0;
1035}
1036
1037int
1038clc_c_to_spirv(const struct clc_compile_args *args,
1039               const struct clc_logger *logger,
1040               struct clc_binary *out_spirv)
1041{
1042   clc_initialize_llvm();
1043
1044   LLVMContext llvm_ctx;
1045   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1046                                         const_cast<clc_logger *>(logger));
1047
1048   auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1049   if (!mod)
1050      return -1;
1051   return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1052}
1053
1054int
1055clc_spir_to_spirv(const struct clc_binary *in_spir,
1056                  const struct clc_logger *logger,
1057                  struct clc_binary *out_spirv)
1058{
1059   clc_initialize_llvm();
1060
1061   LLVMContext llvm_ctx;
1062   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1063                                         const_cast<clc_logger *>(logger));
1064
1065   ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1066   auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1067   if (!mod)
1068      return -1;
1069
1070   return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1071}
1072
1073class SPIRVMessageConsumer {
1074public:
1075   SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1076
1077   void operator()(spv_message_level_t level, const char *src,
1078                   const spv_position_t &pos, const char *msg)
1079   {
1080      if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1081         return;
1082
1083      std::ostringstream message;
1084      message << "(file=" << src
1085              << ",line=" << pos.line
1086              << ",column=" << pos.column
1087              << ",index=" << pos.index
1088              << "): " << msg << "\n";
1089
1090      if (level == SPV_MSG_WARNING)
1091         clc_warning(logger, "%s", message.str().c_str());
1092      else
1093         clc_error(logger, "%s", message.str().c_str());
1094   }
1095
1096private:
1097   const struct clc_logger *logger;
1098};
1099
1100int
1101clc_link_spirv_binaries(const struct clc_linker_args *args,
1102                        const struct clc_logger *logger,
1103                        struct clc_binary *out_spirv)
1104{
1105   std::vector<std::vector<uint32_t>> binaries;
1106
1107   for (unsigned i = 0; i < args->num_in_objs; i++) {
1108      const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1109      std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1110      binaries.push_back(bin);
1111   }
1112
1113   SPIRVMessageConsumer msgconsumer(logger);
1114   spvtools::Context context(spirv_target);
1115   context.SetMessageConsumer(msgconsumer);
1116   spvtools::LinkerOptions options;
1117   options.SetAllowPartialLinkage(args->create_library);
1118   options.SetCreateLibrary(args->create_library);
1119   std::vector<uint32_t> linkingResult;
1120   spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1121   if (status != SPV_SUCCESS) {
1122      return -1;
1123   }
1124
1125   out_spirv->size = linkingResult.size() * 4;
1126   out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1127   memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1128
1129   return 0;
1130}
1131
1132int
1133clc_spirv_specialize(const struct clc_binary *in_spirv,
1134                     const struct clc_parsed_spirv *parsed_data,
1135                     const struct clc_spirv_specialization_consts *consts,
1136                     struct clc_binary *out_spirv)
1137{
1138   std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1139   for (unsigned i = 0; i < consts->num_specializations; ++i) {
1140      unsigned id = consts->specializations[i].id;
1141      auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1142         parsed_data->spec_constants + parsed_data->num_spec_constants,
1143         [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1144      assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1145
1146      std::vector<uint32_t> words;
1147      switch (parsed_spec_const->type) {
1148      case CLC_SPEC_CONSTANT_BOOL:
1149         words.push_back(consts->specializations[i].value.b);
1150         break;
1151      case CLC_SPEC_CONSTANT_INT32:
1152      case CLC_SPEC_CONSTANT_UINT32:
1153      case CLC_SPEC_CONSTANT_FLOAT:
1154         words.push_back(consts->specializations[i].value.u32);
1155         break;
1156      case CLC_SPEC_CONSTANT_INT16:
1157         words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1158         break;
1159      case CLC_SPEC_CONSTANT_INT8:
1160         words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1161         break;
1162      case CLC_SPEC_CONSTANT_UINT16:
1163         words.push_back((uint32_t)consts->specializations[i].value.u16);
1164         break;
1165      case CLC_SPEC_CONSTANT_UINT8:
1166         words.push_back((uint32_t)consts->specializations[i].value.u8);
1167         break;
1168      case CLC_SPEC_CONSTANT_DOUBLE:
1169      case CLC_SPEC_CONSTANT_INT64:
1170      case CLC_SPEC_CONSTANT_UINT64:
1171         words.resize(2);
1172         memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1173         break;
1174      case CLC_SPEC_CONSTANT_UNKNOWN:
1175         assert(0);
1176         break;
1177      }
1178
1179      ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1180      assert(ret.second);
1181   }
1182
1183   spvtools::Optimizer opt(spirv_target);
1184   opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1185
1186   std::vector<uint32_t> result;
1187   if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1188      return false;
1189
1190   out_spirv->size = result.size() * 4;
1191   out_spirv->data = malloc(out_spirv->size);
1192   memcpy(out_spirv->data, result.data(), out_spirv->size);
1193   return true;
1194}
1195
1196void
1197clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1198{
1199   spvtools::SpirvTools tools(spirv_target);
1200   const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1201   std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1202   std::string out;
1203   tools.Disassemble(bin, &out,
1204                     SPV_BINARY_TO_TEXT_OPTION_INDENT |
1205                     SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1206   fwrite(out.c_str(), out.size(), 1, f);
1207}
1208
1209void
1210clc_free_spir_binary(struct clc_binary *spir)
1211{
1212   free(spir->data);
1213}
1214
1215void
1216clc_free_spirv_binary(struct clc_binary *spvbin)
1217{
1218   free(spvbin->data);
1219}
1220
1221void
1222initialize_llvm_once(void)
1223{
1224   LLVMInitializeAllTargets();
1225   LLVMInitializeAllTargetInfos();
1226   LLVMInitializeAllTargetMCs();
1227   LLVMInitializeAllAsmParsers();
1228   LLVMInitializeAllAsmPrinters();
1229}
1230
1231std::once_flag initialize_llvm_once_flag;
1232
1233void
1234clc_initialize_llvm(void)
1235{
1236   std::call_once(initialize_llvm_once_flag,
1237                  []() { initialize_llvm_once(); });
1238}
1239