1#include "dxil_validator.h"
2
3#ifndef WIN32_LEAN_AND_MEAN
4#define WIN32_LEAN_AND_MEAN 1
5#endif
6
7#include <windows.h>
8#include <unknwn.h>
9
10#include "util/ralloc.h"
11#include "util/u_debug.h"
12#include "util/compiler.h"
13
14#include "dxcapi.h"
15
16#include <wrl/client.h>
17using Microsoft::WRL::ComPtr;
18
/* State kept for one validator instance created by dxil_create_validator(). */
struct dxil_validator {
   /* DXIL.dll, loaded by load_dxil_mod(); required. */
   HMODULE dxil_mod;
   /* dxcompiler.dll; optional (diagnostics only), NULL if it failed to load. */
   HMODULE dxcompiler_mod;

   /* Created through DXIL.dll's DxcCreateInstance; always set in a validator
    * returned by dxil_create_validator(). */
   IDxcValidator *dxc_validator;
   /* Created through dxcompiler.dll when available; converts error and
    * disassembly blobs to UTF-8. May be NULL. */
   IDxcLibrary *dxc_library;
   /* Created through dxcompiler.dll when available; used for disassembly.
    * May be NULL. */
   IDxcCompiler *dxc_compiler;

   /* Version reported by dxc_validator, after get_filtered_validator_version(). */
   enum dxil_validator_version version;
};
29
/* Image base of the module this code is linked into. load_dxil_mod() passes
 * its address as the HINSTANCE for GetModuleFileNameA() to find this module's
 * own path. NOTE(review): presumably supplied by the linker
 * (MSVC/MinGW pseudo-symbol) — confirm for all supported toolchains.
 */
extern "C" {
extern IMAGE_DOS_HEADER __ImageBase;
}
33
34static HMODULE
35load_dxil_mod()
36{
37   /* First, try to load DXIL.dll from the default search-path */
38   HMODULE mod = LoadLibraryA("DXIL.dll");
39   if (mod)
40      return mod;
41
42   /* If that fails, try to load it next to the current module, so we can
43    * ship DXIL.dll next to the GLon12 DLL.
44    */
45
46   char self_path[MAX_PATH];
47   uint32_t path_size = GetModuleFileNameA((HINSTANCE)&__ImageBase,
48                                           self_path, sizeof(self_path));
49   if (!path_size || path_size == sizeof(self_path)) {
50      debug_printf("DXIL: Unable to get path to self");
51      return NULL;
52   }
53
54   auto last_slash = strrchr(self_path, '\\');
55   if (!last_slash) {
56      debug_printf("DXIL: Unable to get path to self");
57      return NULL;
58   }
59
60   *(last_slash + 1) = '\0';
61   if (strcat_s(self_path, "DXIL.dll") != 0) {
62      debug_printf("DXIL: Unable to get path to DXIL.dll next to self");
63      return NULL;
64   }
65
66   return LoadLibraryA(self_path);
67}
68
69static IDxcValidator *
70create_dxc_validator(HMODULE dxil_mod)
71{
72   DxcCreateInstanceProc dxil_create_func =
73      (DxcCreateInstanceProc)GetProcAddress(dxil_mod, "DxcCreateInstance");
74   if (!dxil_create_func) {
75      debug_printf("DXIL: Failed to load DxcCreateInstance from DXIL.dll\n");
76      return NULL;
77   }
78
79   IDxcValidator *dxc_validator;
80   HRESULT hr = dxil_create_func(CLSID_DxcValidator,
81                                 IID_PPV_ARGS(&dxc_validator));
82   if (FAILED(hr)) {
83      debug_printf("DXIL: Failed to create validator\n");
84      return NULL;
85   }
86
87   return dxc_validator;
88}
89
90static enum dxil_validator_version
91get_validator_version(IDxcValidator *val)
92{
93   ComPtr<IDxcVersionInfo> version_info;
94   if (FAILED(val->QueryInterface(version_info.ReleaseAndGetAddressOf())))
95      return NO_DXIL_VALIDATION;
96
97   UINT32 major, minor;
98   if (FAILED(version_info->GetVersion(&major, &minor)))
99      return NO_DXIL_VALIDATION;
100
101   if (major == 1)
102      return (enum dxil_validator_version)(DXIL_VALIDATOR_1_0 + MIN2(minor, 7));
103   if (major > 1)
104      return DXIL_VALIDATOR_1_7;
105   return NO_DXIL_VALIDATION;
106}
107
108static uint64_t
109get_dll_version(HMODULE mod)
110{
111   WCHAR filename[MAX_PATH];
112   DWORD filename_length = GetModuleFileNameW(mod, filename, ARRAY_SIZE(filename));
113
114   if (filename_length == 0 || filename_length == ARRAY_SIZE(filename))
115      return 0;
116
117   DWORD version_handle = 0;
118   DWORD version_size = GetFileVersionInfoSizeW(filename, &version_handle);
119   if (version_size == 0)
120      return 0;
121
122   void *version_data = malloc(version_size);
123   if (!version_data)
124      return 0;
125
126   if (!GetFileVersionInfoW(filename, version_handle, version_size, version_data)) {
127      free(version_data);
128      return 0;
129   }
130
131   UINT value_size = 0;
132   VS_FIXEDFILEINFO *version_info = nullptr;
133   if (!VerQueryValueW(version_data, L"\\", reinterpret_cast<void **>(&version_info), &value_size) ||
134       !value_size ||
135       version_info->dwSignature != VS_FFI_SIGNATURE) {
136      free(version_data);
137      return 0;
138   }
139
140   uint64_t ret =
141      ((uint64_t)version_info->dwFileVersionMS << 32ull) |
142      (uint64_t)version_info->dwFileVersionLS;
143   free(version_data);
144   return ret;
145}
146
147static enum dxil_validator_version
148get_filtered_validator_version(HMODULE mod, enum dxil_validator_version raw)
149{
150   switch (raw) {
151   case DXIL_VALIDATOR_1_6: {
152      uint64_t dxil_version = get_dll_version(mod);
153      static constexpr uint64_t known_bad_version =
154         // 101.5.2005.60
155         (101ull << 48ull) | (5ull << 32ull) | (2005ull << 16ull) | 60ull;
156      if (dxil_version == known_bad_version)
157         return DXIL_VALIDATOR_1_5;
158      FALLTHROUGH;
159   }
160   default:
161      return raw;
162   }
163}
164
165struct dxil_validator *
166dxil_create_validator(const void *ctx)
167{
168   struct dxil_validator *val = rzalloc(ctx, struct dxil_validator);
169   if (!val)
170      return NULL;
171
172   /* Load DXIL.dll. This is a hard requirement on Windows, so we error
173    * out if this fails.
174    */
175   val->dxil_mod = load_dxil_mod();
176   if (!val->dxil_mod) {
177      debug_printf("DXIL: Failed to load DXIL.dll\n");
178      goto fail;
179   }
180
181   /* Create IDxcValidator. This is a hard requirement on Windows, so we
182    * error out if this fails.
183    */
184   val->dxc_validator = create_dxc_validator(val->dxil_mod);
185   if (!val->dxc_validator)
186      goto fail;
187
188   val->version = get_filtered_validator_version(
189      val->dxil_mod,
190      get_validator_version(val->dxc_validator));
191
192   /* Try to load dxcompiler.dll. This is just used for diagnostics, and
193    * will fail on most end-users install. So we do not error out if this
194    * fails.
195    */
196   val->dxcompiler_mod = LoadLibraryA("dxcompiler.dll");
197   if (val->dxcompiler_mod) {
198      /* If we managed to load dxcompiler.dll, but either don't find
199       * DxcCreateInstance, or fail to create IDxcLibrary or
200       * IDxcCompiler, this is a good indication that the user wants
201       * diagnostics, but something went wrong. Print warnings to help
202       * figuring out what's wrong, but do not treat it as an error.
203       */
204      DxcCreateInstanceProc compiler_create_func =
205         (DxcCreateInstanceProc)GetProcAddress(val->dxcompiler_mod,
206                                               "DxcCreateInstance");
207      if (!compiler_create_func) {
208         debug_printf("DXIL: Failed to load DxcCreateInstance from "
209                      "dxcompiler.dll\n");
210      } else {
211         if (FAILED(compiler_create_func(CLSID_DxcLibrary,
212                                         IID_PPV_ARGS(&val->dxc_library))))
213            debug_printf("DXIL: Unable to create IDxcLibrary instance\n");
214
215         if (FAILED(compiler_create_func(CLSID_DxcCompiler,
216                                         IID_PPV_ARGS(&val->dxc_compiler))))
217            debug_printf("DXIL: Unable to create IDxcCompiler instance\n");
218      }
219   }
220
221   return val;
222
223fail:
224   if (val->dxil_mod)
225      FreeLibrary(val->dxil_mod);
226
227   ralloc_free(val);
228   return NULL;
229}
230
231void
232dxil_destroy_validator(struct dxil_validator *val)
233{
234   /* if we have a validator, we have these */
235   val->dxc_validator->Release();
236   FreeLibrary(val->dxil_mod);
237
238   if (val->dxcompiler_mod) {
239      if (val->dxc_library)
240         val->dxc_library->Release();
241
242      if (val->dxc_compiler)
243         val->dxc_compiler->Release();
244
245      FreeLibrary(val->dxcompiler_mod);
246   }
247
248   ralloc_free(val);
249}
250
251class ShaderBlob : public IDxcBlob {
252public:
253   ShaderBlob(void *data, size_t size) :
254      m_data(data),
255      m_size(size)
256   {
257   }
258
259   LPVOID STDMETHODCALLTYPE
260   GetBufferPointer(void) override
261   {
262      return m_data;
263   }
264
265   SIZE_T STDMETHODCALLTYPE
266   GetBufferSize() override
267   {
268      return m_size;
269   }
270
271   HRESULT STDMETHODCALLTYPE
272   QueryInterface(REFIID, void **) override
273   {
274      return E_NOINTERFACE;
275   }
276
277   ULONG STDMETHODCALLTYPE
278   AddRef() override
279   {
280      return 1;
281   }
282
283   ULONG STDMETHODCALLTYPE
284   Release() override
285   {
286      return 0;
287   }
288
289   void *m_data;
290   size_t m_size;
291};
292
293bool
294dxil_validate_module(struct dxil_validator *val, void *data, size_t size, char **error)
295{
296   ShaderBlob source(data, size);
297
298   ComPtr<IDxcOperationResult> result;
299   val->dxc_validator->Validate(&source, DxcValidatorFlags_InPlaceEdit,
300                                &result);
301
302   HRESULT hr;
303   result->GetStatus(&hr);
304
305   if (FAILED(hr) && error) {
306      /* try to resolve error message */
307      *error = NULL;
308      if (!val->dxc_library) {
309         debug_printf("DXIL: validation failed, but lacking IDxcLibrary"
310                      "from dxcompiler.dll for proper diagnostics.\n");
311         return false;
312      }
313
314      ComPtr<IDxcBlobEncoding> blob, blob_utf8;
315
316      if (FAILED(result->GetErrorBuffer(&blob)))
317         fprintf(stderr, "DXIL: IDxcOperationResult::GetErrorBuffer() failed\n");
318      else if (FAILED(val->dxc_library->GetBlobAsUtf8(blob.Get(),
319                                                      blob_utf8.GetAddressOf())))
320         fprintf(stderr, "DXIL: IDxcLibrary::GetBlobAsUtf8() failed\n");
321      else {
322         char *str = reinterpret_cast<char *>(blob_utf8->GetBufferPointer());
323         str[blob_utf8->GetBufferSize() - 1] = 0;
324         *error = ralloc_strdup(val, str);
325      }
326   }
327
328   return SUCCEEDED(hr);
329}
330
331char *
332dxil_disasm_module(struct dxil_validator *val, void *data, size_t size)
333{
334   if (!val->dxc_compiler || !val->dxc_library) {
335      fprintf(stderr, "DXIL: disassembly requires IDxcLibrary and "
336              "IDxcCompiler from dxcompiler.dll\n");
337      return NULL;
338   }
339
340   ShaderBlob source(data, size);
341   ComPtr<IDxcBlobEncoding> blob, blob_utf8;
342
343   if (FAILED(val->dxc_compiler->Disassemble(&source, &blob))) {
344      fprintf(stderr, "DXIL: IDxcCompiler::Disassemble() failed\n");
345      return NULL;
346   }
347
348   if (FAILED(val->dxc_library->GetBlobAsUtf8(blob.Get(), blob_utf8.GetAddressOf()))) {
349      fprintf(stderr, "DXIL: IDxcLibrary::GetBlobAsUtf8() failed\n");
350      return NULL;
351   }
352
353   char *str = reinterpret_cast<char*>(blob_utf8->GetBufferPointer());
354   str[blob_utf8->GetBufferSize() - 1] = 0;
355   return ralloc_strdup(val, str);
356}
357
358enum dxil_validator_version
359dxil_get_validator_version(struct dxil_validator *val)
360{
361   return val->version;
362}
363