1/*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include <stdio.h>
25#include <stdint.h>
26#include <stdexcept>
27
28#include <unknwn.h>
29#include <directx/d3d12.h>
30#include <dxgi1_4.h>
31#include <gtest/gtest.h>
32#include <wrl.h>
33#include <dxguids/dxguids.h>
34
35#include "util/u_debug.h"
36#include "clc_compiler.h"
37#include "compute_test.h"
38#include "dxil_validator.h"
39
40#include <spirv-tools/libspirv.hpp>
41
42#if (defined(_WIN32) && defined(_MSC_VER)) || D3D12_SDK_VERSION < 606
43inline D3D12_CPU_DESCRIPTOR_HANDLE
44GetCPUDescriptorHandleForHeapStart(ID3D12DescriptorHeap *heap)
45{
46   return heap->GetCPUDescriptorHandleForHeapStart();
47}
48inline D3D12_GPU_DESCRIPTOR_HANDLE
49GetGPUDescriptorHandleForHeapStart(ID3D12DescriptorHeap *heap)
50{
51   return heap->GetGPUDescriptorHandleForHeapStart();
52}
53inline D3D12_HEAP_PROPERTIES
54GetCustomHeapProperties(ID3D12Device *dev, D3D12_HEAP_TYPE type)
55{
56   return dev->GetCustomHeapProperties(0, type);
57}
58#else
59inline D3D12_CPU_DESCRIPTOR_HANDLE
60GetCPUDescriptorHandleForHeapStart(ID3D12DescriptorHeap *heap)
61{
62   D3D12_CPU_DESCRIPTOR_HANDLE ret;
63   heap->GetCPUDescriptorHandleForHeapStart(&ret);
64   return ret;
65}
66inline D3D12_GPU_DESCRIPTOR_HANDLE
67GetGPUDescriptorHandleForHeapStart(ID3D12DescriptorHeap *heap)
68{
69   D3D12_GPU_DESCRIPTOR_HANDLE ret;
70   heap->GetGPUDescriptorHandleForHeapStart(&ret);
71   return ret;
72}
73inline D3D12_HEAP_PROPERTIES
74GetCustomHeapProperties(ID3D12Device *dev, D3D12_HEAP_TYPE type)
75{
76   D3D12_HEAP_PROPERTIES ret;
77   dev->GetCustomHeapProperties(&ret, 0, type);
78   return ret;
79}
80#endif
81
82using std::runtime_error;
83using Microsoft::WRL::ComPtr;
84
/* Bit flags for the test harness; each value occupies its own bit. */
enum compute_test_debug_flags {
   /* Request the experimental shader models before creating the device. */
   COMPUTE_DEBUG_EXPERIMENTAL_SHADERS = 0x1,
   /* Pick an adapter by enumeration instead of asking for the WARP adapter. */
   COMPUTE_DEBUG_USE_HW_D3D           = 0x2,
   /* Ask for libclc to be optimized when it is created. */
   COMPUTE_DEBUG_OPTIMIZE_LIBCLC      = 0x4,
   /* Round-trip libclc through serialize/deserialize before use. */
   COMPUTE_DEBUG_SERIALIZE_LIBCLC     = 0x8,
};
91
92static const struct debug_named_value compute_debug_options[] = {
93   { "experimental_shaders",  COMPUTE_DEBUG_EXPERIMENTAL_SHADERS, "Enable experimental shaders" },
94   { "use_hw_d3d",            COMPUTE_DEBUG_USE_HW_D3D,           "Use a hardware D3D device"   },
95   { "optimize_libclc",       COMPUTE_DEBUG_OPTIMIZE_LIBCLC,      "Optimize the clc_libclc before using it" },
96   { "serialize_libclc",      COMPUTE_DEBUG_SERIALIZE_LIBCLC,     "Serialize and deserialize the clc_libclc" },
97   DEBUG_NAMED_VALUE_END
98};
99
100DEBUG_GET_ONCE_FLAGS_OPTION(debug_compute, "COMPUTE_TEST_DEBUG", compute_debug_options, 0)
101
/* Logger hook: prints a compiler warning to stderr. `priv` is not used. */
static void
warning_callback(void *priv, const char *msg)
{
   (void)priv;
   fprintf(stderr, "WARNING: %s\n", msg);
}
106
/* Logger hook: prints a compiler error to stderr. `priv` is not used. */
static void
error_callback(void *priv, const char *msg)
{
   (void)priv;
   fprintf(stderr, "ERROR: %s\n", msg);
}
111
112static const struct clc_logger logger = {
113   NULL,
114   error_callback,
115   warning_callback,
116};
117
118void
119ComputeTest::enable_d3d12_debug_layer()
120{
121   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
122   if (!hD3D12Mod) {
123      fprintf(stderr, "D3D12: failed to load D3D12.DLL\n");
124      return;
125   }
126
127   typedef HRESULT(WINAPI * PFN_D3D12_GET_DEBUG_INTERFACE)(REFIID riid,
128                                                           void **ppFactory);
129   PFN_D3D12_GET_DEBUG_INTERFACE D3D12GetDebugInterface = (PFN_D3D12_GET_DEBUG_INTERFACE)GetProcAddress(hD3D12Mod, "D3D12GetDebugInterface");
130   if (!D3D12GetDebugInterface) {
131      fprintf(stderr, "D3D12: failed to load D3D12GetDebugInterface from D3D12.DLL\n");
132      return;
133   }
134
135   ID3D12Debug *debug;
136   if (FAILED(D3D12GetDebugInterface(__uuidof(ID3D12Debug), (void **)& debug))) {
137      fprintf(stderr, "D3D12: D3D12GetDebugInterface failed\n");
138      return;
139   }
140
141   debug->EnableDebugLayer();
142}
143
144IDXGIFactory4 *
145ComputeTest::get_dxgi_factory()
146{
147   static const GUID IID_IDXGIFactory4 = {
148      0x1bc6ea02, 0xef36, 0x464f,
149      { 0xbf, 0x0c, 0x21, 0xca, 0x39, 0xe5, 0x16, 0x8a }
150   };
151
152   typedef HRESULT(WINAPI * PFN_CREATE_DXGI_FACTORY)(REFIID riid,
153                                                     void **ppFactory);
154   PFN_CREATE_DXGI_FACTORY CreateDXGIFactory;
155
156   HMODULE hDXGIMod = LoadLibrary("DXGI.DLL");
157   if (!hDXGIMod)
158      throw runtime_error("Failed to load DXGI.DLL");
159
160   CreateDXGIFactory = (PFN_CREATE_DXGI_FACTORY)GetProcAddress(hDXGIMod, "CreateDXGIFactory");
161   if (!CreateDXGIFactory)
162      throw runtime_error("Failed to load CreateDXGIFactory from DXGI.DLL");
163
164   IDXGIFactory4 *factory = NULL;
165   HRESULT hr = CreateDXGIFactory(IID_IDXGIFactory4, (void **)&factory);
166   if (FAILED(hr))
167      throw runtime_error("CreateDXGIFactory failed");
168
169   return factory;
170}
171
172IDXGIAdapter1 *
173ComputeTest::choose_adapter(IDXGIFactory4 *factory)
174{
175   IDXGIAdapter1 *ret;
176
177   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_USE_HW_D3D) {
178      for (unsigned i = 0; SUCCEEDED(factory->EnumAdapters1(i, &ret)); i++) {
179         DXGI_ADAPTER_DESC1 desc;
180         ret->GetDesc1(&desc);
181         if (!(desc.Flags & D3D_DRIVER_TYPE_SOFTWARE))
182            return ret;
183      }
184      throw runtime_error("Failed to enum hardware adapter");
185   } else {
186      if (FAILED(factory->EnumWarpAdapter(__uuidof(IDXGIAdapter1),
187         (void **)& ret)))
188         throw runtime_error("Failed to enum warp adapter");
189      return ret;
190   }
191}
192
193ID3D12Device *
194ComputeTest::create_device(IDXGIAdapter1 *adapter)
195{
196   typedef HRESULT(WINAPI *PFN_D3D12CREATEDEVICE)(IUnknown *, D3D_FEATURE_LEVEL, REFIID, void **);
197   PFN_D3D12CREATEDEVICE D3D12CreateDevice;
198
199   HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
200   if (!hD3D12Mod)
201      throw runtime_error("failed to load D3D12.DLL");
202
203   if (debug_get_option_debug_compute() & COMPUTE_DEBUG_EXPERIMENTAL_SHADERS) {
204      typedef HRESULT(WINAPI *PFN_D3D12ENABLEEXPERIMENTALFEATURES)(UINT, const IID *, void *, UINT *);
205      PFN_D3D12ENABLEEXPERIMENTALFEATURES D3D12EnableExperimentalFeatures;
206      D3D12EnableExperimentalFeatures = (PFN_D3D12ENABLEEXPERIMENTALFEATURES)
207         GetProcAddress(hD3D12Mod, "D3D12EnableExperimentalFeatures");
208      if (FAILED(D3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModels, NULL, NULL)))
209         throw runtime_error("failed to enable experimental shader models");
210   }
211
212   D3D12CreateDevice = (PFN_D3D12CREATEDEVICE)GetProcAddress(hD3D12Mod, "D3D12CreateDevice");
213   if (!D3D12CreateDevice)
214      throw runtime_error("failed to load D3D12CreateDevice from D3D12.DLL");
215
216   ID3D12Device *dev;
217   if (FAILED(D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_12_0,
218       __uuidof(ID3D12Device), (void **)& dev)))
219      throw runtime_error("D3D12CreateDevice failed");
220
221   return dev;
222}
223
224ComPtr<ID3D12RootSignature>
225ComputeTest::create_root_signature(const ComputeTest::Resources &resources)
226{
227   D3D12_ROOT_PARAMETER1 root_param;
228   root_param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
229   root_param.DescriptorTable.NumDescriptorRanges = resources.ranges.size();
230   root_param.DescriptorTable.pDescriptorRanges = resources.ranges.data();
231   root_param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
232
233   D3D12_ROOT_SIGNATURE_DESC1 root_sig_desc;
234   root_sig_desc.NumParameters = 1;
235   root_sig_desc.pParameters = &root_param;
236   root_sig_desc.NumStaticSamplers = 0;
237   root_sig_desc.pStaticSamplers = NULL;
238   root_sig_desc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
239
240   D3D12_VERSIONED_ROOT_SIGNATURE_DESC versioned_desc;
241   versioned_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
242   versioned_desc.Desc_1_1 = root_sig_desc;
243
244   ID3DBlob *sig, *error;
245   if (FAILED(D3D12SerializeVersionedRootSignature(&versioned_desc,
246       &sig, &error)))
247      throw runtime_error("D3D12SerializeVersionedRootSignature failed");
248
249   ComPtr<ID3D12RootSignature> ret;
250   if (FAILED(dev->CreateRootSignature(0,
251       sig->GetBufferPointer(),
252       sig->GetBufferSize(),
253       __uuidof(ID3D12RootSignature),
254       (void **)& ret)))
255      throw runtime_error("CreateRootSignature failed");
256
257   return ret;
258}
259
260ComPtr<ID3D12PipelineState>
261ComputeTest::create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
262                                   const struct clc_dxil_object &dxil)
263{
264   D3D12_COMPUTE_PIPELINE_STATE_DESC pipeline_desc = { root_sig.Get() };
265   pipeline_desc.CS.pShaderBytecode = dxil.binary.data;
266   pipeline_desc.CS.BytecodeLength = dxil.binary.size;
267
268   ComPtr<ID3D12PipelineState> pipeline_state;
269   if (FAILED(dev->CreateComputePipelineState(&pipeline_desc,
270                                              __uuidof(ID3D12PipelineState),
271                                              (void **)& pipeline_state)))
272      throw runtime_error("Failed to create pipeline state");
273   return pipeline_state;
274}
275
276ComPtr<ID3D12Resource>
277ComputeTest::create_buffer(int size, D3D12_HEAP_TYPE heap_type)
278{
279   D3D12_RESOURCE_DESC desc;
280   desc.Format = DXGI_FORMAT_UNKNOWN;
281   desc.Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
282   desc.Width = size;
283   desc.Height = 1;
284   desc.DepthOrArraySize = 1;
285   desc.MipLevels = 1;
286   desc.SampleDesc.Count = 1;
287   desc.SampleDesc.Quality = 0;
288   desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
289   desc.Flags = heap_type == D3D12_HEAP_TYPE_DEFAULT ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS : D3D12_RESOURCE_FLAG_NONE;
290   desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
291
292   D3D12_HEAP_PROPERTIES heap_pris = GetCustomHeapProperties(dev, heap_type);
293
294   ComPtr<ID3D12Resource> res;
295   if (FAILED(dev->CreateCommittedResource(&heap_pris,
296       D3D12_HEAP_FLAG_NONE, &desc, D3D12_RESOURCE_STATE_COMMON,
297       NULL, __uuidof(ID3D12Resource), (void **)&res)))
298      throw runtime_error("CreateCommittedResource failed");
299
300   return res;
301}
302
303ComPtr<ID3D12Resource>
304ComputeTest::create_upload_buffer_with_data(const void *data, size_t size)
305{
306   auto upload_res = create_buffer(size, D3D12_HEAP_TYPE_UPLOAD);
307
308   void *ptr = NULL;
309   D3D12_RANGE res_range = { 0, (SIZE_T)size };
310   if (FAILED(upload_res->Map(0, &res_range, (void **)&ptr)))
311      throw runtime_error("Failed to map upload-buffer");
312   assert(ptr);
313   memcpy(ptr, data, size);
314   upload_res->Unmap(0, &res_range);
315   return upload_res;
316}
317
318ComPtr<ID3D12Resource>
319ComputeTest::create_sized_buffer_with_data(size_t buffer_size,
320                                           const void *data,
321                                           size_t data_size)
322{
323   auto upload_res = create_upload_buffer_with_data(data, data_size);
324
325   auto res = create_buffer(buffer_size, D3D12_HEAP_TYPE_DEFAULT);
326   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST);
327   cmdlist->CopyBufferRegion(res.Get(), 0, upload_res.Get(), 0, data_size);
328   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON);
329   execute_cmdlist();
330
331   return res;
332}
333
334void
335ComputeTest::get_buffer_data(ComPtr<ID3D12Resource> res,
336                             void *buf, size_t size)
337{
338   auto readback_res = create_buffer(align(size, 4), D3D12_HEAP_TYPE_READBACK);
339   resource_barrier(res, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_SOURCE);
340   cmdlist->CopyResource(readback_res.Get(), res.Get());
341   resource_barrier(res, D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_COMMON);
342   execute_cmdlist();
343
344   void *ptr = NULL;
345   D3D12_RANGE res_range = { 0, size };
346   if (FAILED(readback_res->Map(0, &res_range, &ptr)))
347      throw runtime_error("Failed to map readback-buffer");
348
349   memcpy(buf, ptr, size);
350
351   D3D12_RANGE empty_range = { 0, 0 };
352   readback_res->Unmap(0, &empty_range);
353}
354
355void
356ComputeTest::resource_barrier(ComPtr<ID3D12Resource> &res,
357                              D3D12_RESOURCE_STATES state_before,
358                              D3D12_RESOURCE_STATES state_after)
359{
360   D3D12_RESOURCE_BARRIER barrier;
361   barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
362   barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
363   barrier.Transition.pResource = res.Get();
364   barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
365   barrier.Transition.StateBefore = state_before;
366   barrier.Transition.StateAfter = state_after;
367   cmdlist->ResourceBarrier(1, &barrier);
368}
369
370void
371ComputeTest::execute_cmdlist()
372{
373   if (FAILED(cmdlist->Close()))
374      throw runtime_error("Closing ID3D12GraphicsCommandList failed");
375
376   ID3D12CommandList *cmdlists[] = { cmdlist };
377   cmdqueue->ExecuteCommandLists(1, cmdlists);
378   cmdqueue_fence->SetEventOnCompletion(fence_value, event);
379   cmdqueue->Signal(cmdqueue_fence, fence_value);
380   fence_value++;
381   WaitForSingleObject(event, INFINITE);
382
383   if (FAILED(cmdalloc->Reset()))
384      throw runtime_error("resetting ID3D12CommandAllocator failed");
385
386   if (FAILED(cmdlist->Reset(cmdalloc, NULL)))
387      throw runtime_error("resetting ID3D12GraphicsCommandList failed");
388}
389
390void
391ComputeTest::create_uav_buffer(ComPtr<ID3D12Resource> res,
392                               size_t width, size_t byte_stride,
393                               D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
394{
395   D3D12_UNORDERED_ACCESS_VIEW_DESC uav_desc;
396   uav_desc.Format = DXGI_FORMAT_R32_TYPELESS;
397   uav_desc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
398   uav_desc.Buffer.FirstElement = 0;
399   uav_desc.Buffer.NumElements = DIV_ROUND_UP(width * byte_stride, 4);
400   uav_desc.Buffer.StructureByteStride = 0;
401   uav_desc.Buffer.CounterOffsetInBytes = 0;
402   uav_desc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
403
404   dev->CreateUnorderedAccessView(res.Get(), NULL, &uav_desc, cpu_handle);
405}
406
407void
408ComputeTest::create_cbv(ComPtr<ID3D12Resource> res, size_t size,
409                        D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle)
410{
411   D3D12_CONSTANT_BUFFER_VIEW_DESC cbv_desc;
412   cbv_desc.BufferLocation = res ? res->GetGPUVirtualAddress() : 0;
413   cbv_desc.SizeInBytes = size;
414
415   dev->CreateConstantBufferView(&cbv_desc, cpu_handle);
416}
417
418ComPtr<ID3D12Resource>
419ComputeTest::add_uav_resource(ComputeTest::Resources &resources,
420                              unsigned spaceid, unsigned resid,
421                              const void *data, size_t num_elems,
422                              size_t elem_size)
423{
424   size_t size = align(elem_size * num_elems, 4);
425   D3D12_CPU_DESCRIPTOR_HANDLE handle;
426   ComPtr<ID3D12Resource> res;
427   handle = GetCPUDescriptorHandleForHeapStart(uav_heap);
428   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
429
430   if (size) {
431      if (data)
432         res = create_buffer_with_data(data, size);
433      else
434         res = create_buffer(size, D3D12_HEAP_TYPE_DEFAULT);
435
436      resource_barrier(res, D3D12_RESOURCE_STATE_COMMON,
437                       D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
438   }
439   create_uav_buffer(res, num_elems, elem_size, handle);
440   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_UAV, spaceid, resid);
441   return res;
442}
443
444ComPtr<ID3D12Resource>
445ComputeTest::add_cbv_resource(ComputeTest::Resources &resources,
446                              unsigned spaceid, unsigned resid,
447                              const void *data, size_t size)
448{
449   unsigned aligned_size = align(size, 256);
450   D3D12_CPU_DESCRIPTOR_HANDLE handle;
451   ComPtr<ID3D12Resource> res;
452   handle = GetCPUDescriptorHandleForHeapStart(uav_heap);
453   handle = offset_cpu_handle(handle, resources.descs.size() * uav_heap_incr);
454
455   if (size) {
456     assert(data);
457     res = create_sized_buffer_with_data(aligned_size, data, size);
458   }
459   create_cbv(res, aligned_size, handle);
460   resources.add(res, D3D12_DESCRIPTOR_RANGE_TYPE_CBV, spaceid, resid);
461   return res;
462}
463
464void
465ComputeTest::run_shader_with_raw_args(Shader shader,
466                                      const CompileArgs &compile_args,
467                                      const std::vector<RawShaderArg *> &args)
468{
469   if (args.size() < 1)
470      throw runtime_error("no inputs");
471
472   static HMODULE hD3D12Mod = LoadLibrary("D3D12.DLL");
473   if (!hD3D12Mod)
474      throw runtime_error("Failed to load D3D12.DLL");
475
476   D3D12SerializeVersionedRootSignature = (PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE)GetProcAddress(hD3D12Mod, "D3D12SerializeVersionedRootSignature");
477
478   if (args.size() != shader.dxil->kernel->num_args)
479      throw runtime_error("incorrect number of inputs");
480
481   struct clc_runtime_kernel_conf conf = { 0 };
482
483   // Older WARP and some hardware doesn't support int64, so for these tests, unconditionally lower away int64
484   // A more complex runtime can be smarter about detecting when this needs to be done
485   conf.lower_bit_size = 64;
486
487   if (!shader.dxil->metadata.local_size[0])
488      conf.local_size[0] = compile_args.x;
489   else
490      conf.local_size[0] = shader.dxil->metadata.local_size[0];
491
492   if (!shader.dxil->metadata.local_size[1])
493      conf.local_size[1] = compile_args.y;
494   else
495      conf.local_size[1] = shader.dxil->metadata.local_size[1];
496
497   if (!shader.dxil->metadata.local_size[2])
498      conf.local_size[2] = compile_args.z;
499   else
500      conf.local_size[2] = shader.dxil->metadata.local_size[2];
501
502   if (compile_args.x % conf.local_size[0] ||
503       compile_args.y % conf.local_size[1] ||
504       compile_args.z % conf.local_size[2])
505      throw runtime_error("invalid global size must be a multiple of local size");
506
507   std::vector<struct clc_runtime_arg_info> argsinfo(args.size());
508
509   conf.args = argsinfo.data();
510   conf.support_global_work_id_offsets =
511      compile_args.work_props.global_offset_x != 0 ||
512      compile_args.work_props.global_offset_y != 0 ||
513      compile_args.work_props.global_offset_z != 0;
514   conf.support_workgroup_id_offsets =
515      compile_args.work_props.group_id_offset_x != 0 ||
516      compile_args.work_props.group_id_offset_y != 0 ||
517      compile_args.work_props.group_id_offset_z != 0;
518
519   for (unsigned i = 0; i < shader.dxil->kernel->num_args; ++i) {
520      RawShaderArg *arg = args[i];
521      size_t size = arg->get_elem_size() * arg->get_num_elems();
522
523      switch (shader.dxil->kernel->args[i].address_qualifier) {
524      case CLC_KERNEL_ARG_ADDRESS_LOCAL:
525         argsinfo[i].localptr.size = size;
526         break;
527      default:
528         break;
529      }
530   }
531
532   configure(shader, &conf);
533   validate(shader);
534
535   std::shared_ptr<struct clc_dxil_object> &dxil = shader.dxil;
536
537   std::vector<uint8_t> argsbuf(dxil->metadata.kernel_inputs_buf_size);
538   std::vector<ComPtr<ID3D12Resource>> argres(shader.dxil->kernel->num_args);
539   clc_work_properties_data work_props = compile_args.work_props;
540   if (!conf.support_workgroup_id_offsets) {
541      work_props.group_count_total_x = compile_args.x / conf.local_size[0];
542      work_props.group_count_total_y = compile_args.y / conf.local_size[1];
543      work_props.group_count_total_z = compile_args.z / conf.local_size[2];
544   }
545   if (work_props.work_dim == 0)
546      work_props.work_dim = 3;
547   Resources resources;
548
549   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
550      RawShaderArg *arg = args[i];
551      size_t size = arg->get_elem_size() * arg->get_num_elems();
552      void *slot = argsbuf.data() + dxil->metadata.args[i].offset;
553
554      switch (dxil->kernel->args[i].address_qualifier) {
555      case CLC_KERNEL_ARG_ADDRESS_CONSTANT:
556      case CLC_KERNEL_ARG_ADDRESS_GLOBAL: {
557         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
558         uint64_t *ptr_slot = (uint64_t *)slot;
559         if (arg->get_data())
560            *ptr_slot = (uint64_t)dxil->metadata.args[i].globconstptr.buf_id << 32;
561         else
562            *ptr_slot = ~0ull;
563         break;
564      }
565      case CLC_KERNEL_ARG_ADDRESS_LOCAL: {
566         assert(dxil->metadata.args[i].size == sizeof(uint64_t));
567         uint64_t *ptr_slot = (uint64_t *)slot;
568         *ptr_slot = dxil->metadata.args[i].localptr.sharedmem_offset;
569         break;
570      }
571      case CLC_KERNEL_ARG_ADDRESS_PRIVATE: {
572         assert(size == dxil->metadata.args[i].size);
573         memcpy(slot, arg->get_data(), size);
574         break;
575      }
576      default:
577         assert(0);
578      }
579   }
580
581   for (unsigned i = 0; i < dxil->kernel->num_args; ++i) {
582      RawShaderArg *arg = args[i];
583
584      if (dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL ||
585          dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) {
586         argres[i] = add_uav_resource(resources, 0,
587                                      dxil->metadata.args[i].globconstptr.buf_id,
588                                      arg->get_data(), arg->get_num_elems(),
589                                      arg->get_elem_size());
590      }
591   }
592
593   if (dxil->metadata.printf.uav_id > 0)
594      add_uav_resource(resources, 0, dxil->metadata.printf.uav_id, NULL, 1024 * 1024 / 4, 4);
595
596   for (unsigned i = 0; i < dxil->metadata.num_consts; ++i)
597      add_uav_resource(resources, 0, dxil->metadata.consts[i].uav_id,
598                       dxil->metadata.consts[i].data,
599                       dxil->metadata.consts[i].size / 4, 4);
600
601   if (argsbuf.size())
602      add_cbv_resource(resources, 0, dxil->metadata.kernel_inputs_cbv_id,
603                       argsbuf.data(), argsbuf.size());
604
605   add_cbv_resource(resources, 0, dxil->metadata.work_properties_cbv_id,
606                    &work_props, sizeof(work_props));
607
608   auto root_sig = create_root_signature(resources);
609   auto pipeline_state = create_pipeline_state(root_sig, *dxil);
610
611   cmdlist->SetDescriptorHeaps(1, &uav_heap);
612   cmdlist->SetComputeRootSignature(root_sig.Get());
613   cmdlist->SetComputeRootDescriptorTable(0, GetGPUDescriptorHandleForHeapStart(uav_heap));
614   cmdlist->SetPipelineState(pipeline_state.Get());
615
616   cmdlist->Dispatch(compile_args.x / conf.local_size[0],
617                     compile_args.y / conf.local_size[1],
618                     compile_args.z / conf.local_size[2]);
619
620   for (auto &range : resources.ranges) {
621      if (range.RangeType == D3D12_DESCRIPTOR_RANGE_TYPE_UAV) {
622         for (unsigned i = range.OffsetInDescriptorsFromTableStart;
623              i < range.NumDescriptors; i++) {
624            if (!resources.descs[i].Get())
625               continue;
626
627            resource_barrier(resources.descs[i],
628                             D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
629                             D3D12_RESOURCE_STATE_COMMON);
630         }
631      }
632   }
633
634   execute_cmdlist();
635
636   for (unsigned i = 0; i < args.size(); i++) {
637      if (!(args[i]->get_direction() & SHADER_ARG_OUTPUT))
638         continue;
639
640      assert(dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL);
641      get_buffer_data(argres[i], args[i]->get_data(),
642                      args[i]->get_elem_size() * args[i]->get_num_elems());
643   }
644
645   ComPtr<ID3D12InfoQueue> info_queue;
646   dev->QueryInterface(info_queue.ReleaseAndGetAddressOf());
647   if (info_queue)
648   {
649      EXPECT_EQ(0, info_queue->GetNumStoredMessages());
650      for (unsigned i = 0; i < info_queue->GetNumStoredMessages(); ++i) {
651         SIZE_T message_size = 0;
652         info_queue->GetMessageA(i, nullptr, &message_size);
653         D3D12_MESSAGE* message = (D3D12_MESSAGE*)malloc(message_size);
654         info_queue->GetMessageA(i, message, &message_size);
655         FAIL() << message->pDescription;
656         free(message);
657      }
658   }
659}
660
661void
662ComputeTest::SetUp()
663{
664   static struct clc_libclc *compiler_ctx_g = nullptr;
665
666   if (!compiler_ctx_g) {
667      clc_libclc_dxil_options options = { };
668      options.optimize = (debug_get_option_debug_compute() & COMPUTE_DEBUG_OPTIMIZE_LIBCLC) != 0;
669
670      compiler_ctx_g = clc_libclc_new_dxil(&logger, &options);
671      if (!compiler_ctx_g)
672         throw runtime_error("failed to create CLC compiler context");
673
674      if (debug_get_option_debug_compute() & COMPUTE_DEBUG_SERIALIZE_LIBCLC) {
675         void *serialized = nullptr;
676         size_t serialized_size = 0;
677         clc_libclc_serialize(compiler_ctx_g, &serialized, &serialized_size);
678         if (!serialized)
679            throw runtime_error("failed to serialize CLC compiler context");
680
681         clc_free_libclc(compiler_ctx_g);
682         compiler_ctx_g = nullptr;
683
684         compiler_ctx_g = clc_libclc_deserialize(serialized, serialized_size);
685         if (!compiler_ctx_g)
686            throw runtime_error("failed to deserialize CLC compiler context");
687
688         clc_libclc_free_serialized(serialized);
689      }
690   }
691   compiler_ctx = compiler_ctx_g;
692
693   enable_d3d12_debug_layer();
694
695   factory = get_dxgi_factory();
696   if (!factory)
697      throw runtime_error("failed to create DXGI factory");
698
699   adapter = choose_adapter(factory);
700   if (!adapter)
701      throw runtime_error("failed to choose adapter");
702
703   dev = create_device(adapter);
704   if (!dev)
705      throw runtime_error("failed to create device");
706
707   if (FAILED(dev->CreateFence(0, D3D12_FENCE_FLAG_NONE,
708                               __uuidof(cmdqueue_fence),
709                               (void **)&cmdqueue_fence)))
710      throw runtime_error("failed to create fence\n");
711
712   D3D12_COMMAND_QUEUE_DESC queue_desc;
713   queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
714   queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
715   queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
716   queue_desc.NodeMask = 0;
717   if (FAILED(dev->CreateCommandQueue(&queue_desc,
718                                      __uuidof(cmdqueue),
719                                      (void **)&cmdqueue)))
720      throw runtime_error("failed to create command queue");
721
722   if (FAILED(dev->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE,
723             __uuidof(cmdalloc), (void **)&cmdalloc)))
724      throw runtime_error("failed to create command allocator");
725
726   if (FAILED(dev->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE,
727             cmdalloc, NULL, __uuidof(cmdlist), (void **)&cmdlist)))
728      throw runtime_error("failed to create command list");
729
730   D3D12_DESCRIPTOR_HEAP_DESC heap_desc;
731   heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
732   heap_desc.NumDescriptors = 1000;
733   heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
734   heap_desc.NodeMask = 0;
735   if (FAILED(dev->CreateDescriptorHeap(&heap_desc,
736       __uuidof(uav_heap), (void **)&uav_heap)))
737      throw runtime_error("failed to create descriptor heap");
738
739   uav_heap_incr = dev->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
740
741   event = CreateEvent(NULL, FALSE, FALSE, NULL);
742   if (!event)
743      throw runtime_error("Failed to create event");
744   fence_value = 1;
745}
746
747void
748ComputeTest::TearDown()
749{
750   CloseHandle(event);
751
752   uav_heap->Release();
753   cmdlist->Release();
754   cmdalloc->Release();
755   cmdqueue->Release();
756   cmdqueue_fence->Release();
757   dev->Release();
758   adapter->Release();
759   factory->Release();
760}
761
762PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE ComputeTest::D3D12SerializeVersionedRootSignature;
763
764bool
765validate_module(const struct clc_dxil_object &dxil)
766{
767   struct dxil_validator *val = dxil_create_validator(NULL);
768   char *err;
769   bool res = dxil_validate_module(val, dxil.binary.data,
770                                   dxil.binary.size, &err);
771   if (!res && err)
772      fprintf(stderr, "D3D12: validation failed: %s", err);
773
774   dxil_destroy_validator(val);
775   return res;
776}
777
778static void
779dump_blob(const char *path, const struct clc_dxil_object &dxil)
780{
781   FILE *fp = fopen(path, "wb");
782   if (fp) {
783      fwrite(dxil.binary.data, 1, dxil.binary.size, fp);
784      fclose(fp);
785      printf("D3D12: wrote '%s'...\n", path);
786   }
787}
788
789ComputeTest::Shader
790ComputeTest::compile(const std::vector<const char *> &sources,
791                     const std::vector<const char *> &compile_args,
792                     bool create_library)
793{
794   struct clc_compile_args args = {
795   };
796   args.args = compile_args.data();
797   args.num_args = (unsigned)compile_args.size();
798   ComputeTest::Shader shader;
799
800   std::vector<Shader> shaders;
801
802   args.source.name = "obj.cl";
803
804   for (unsigned i = 0; i < sources.size(); i++) {
805      args.source.value = sources[i];
806
807      clc_binary spirv{};
808      if (!clc_compile_c_to_spirv(&args, &logger, &spirv))
809         throw runtime_error("failed to compile object!");
810
811      Shader shader;
812      shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
813         {
814            clc_free_spirv(spirv);
815            delete spirv;
816         });
817      shaders.push_back(shader);
818   }
819
820   if (shaders.size() == 1 && create_library)
821      return shaders[0];
822
823   return link(shaders, create_library);
824}
825
826ComputeTest::Shader
827ComputeTest::link(const std::vector<Shader> &sources,
828                  bool create_library)
829{
830   std::vector<const clc_binary*> objs;
831   for (auto& source : sources)
832      objs.push_back(&*source.obj);
833
834   struct clc_linker_args link_args = {};
835   link_args.in_objs = objs.data();
836   link_args.num_in_objs = (unsigned)objs.size();
837   link_args.create_library = create_library;
838   clc_binary spirv{};
839   if (!clc_link_spirv(&link_args, &logger, &spirv))
840      throw runtime_error("failed to link objects!");
841
842   ComputeTest::Shader shader;
843   shader.obj = std::shared_ptr<clc_binary>(new clc_binary(spirv), [](clc_binary *spirv)
844      {
845         clc_free_spirv(spirv);
846         delete spirv;
847      });
848   if (!link_args.create_library)
849      configure(shader, NULL);
850
851   return shader;
852}
853
854ComputeTest::Shader
855ComputeTest::assemble(const char *source)
856{
857   spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
858   std::vector<uint32_t> binary;
859   if (!tools.Assemble(source, strlen(source), &binary))
860      throw runtime_error("failed to assemble");
861
862   ComputeTest::Shader shader;
863   shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
864      {
865         free(spirv->data);
866         delete spirv;
867      });
868   shader.obj->size = binary.size() * 4;
869   shader.obj->data = malloc(shader.obj->size);
870   memcpy(shader.obj->data, binary.data(), shader.obj->size);
871
872   configure(shader, NULL);
873
874   return shader;
875}
876
877void
878ComputeTest::configure(Shader &shader,
879                       const struct clc_runtime_kernel_conf *conf)
880{
881   if (!shader.metadata) {
882      shader.metadata = std::shared_ptr<clc_parsed_spirv>(new clc_parsed_spirv{}, [](clc_parsed_spirv *metadata)
883         {
884            clc_free_parsed_spirv(metadata);
885            delete metadata;
886         });
887      if (!clc_parse_spirv(shader.obj.get(), NULL, shader.metadata.get()))
888         throw runtime_error("failed to parse spirv!");
889   }
890
891   std::unique_ptr<clc_dxil_object> dxil(new clc_dxil_object{});
892   if (!clc_spirv_to_dxil(compiler_ctx, shader.obj.get(), shader.metadata.get(), "main_test", conf, nullptr, &logger, dxil.get()))
893      throw runtime_error("failed to compile kernel!");
894   shader.dxil = std::shared_ptr<clc_dxil_object>(dxil.release(), [](clc_dxil_object *dxil)
895      {
896         clc_free_dxil_object(dxil);
897         delete dxil;
898      });
899}
900
901void
902ComputeTest::validate(ComputeTest::Shader &shader)
903{
904   dump_blob("unsigned.cso", *shader.dxil);
905   if (!validate_module(*shader.dxil))
906      throw runtime_error("failed to validate module!");
907
908   dump_blob("signed.cso", *shader.dxil);
909}
910