1/**************************************************************************
2 *
3 * Copyright 2009-2013 VMware, Inc.
4 * All Rights Reserved.
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 *
26 **************************************************************************/
27
28#include <windows.h>
29#include <tlhelp32.h>
30
31#include "pipe/p_compiler.h"
32#include "util/u_debug.h"
33#include "stw_tls.h"
34
/**
 * TLS slot holding each thread's struct stw_tls_data pointer.
 *
 * Allocated by stw_tls_init() via TlsAlloc() and released by
 * stw_tls_cleanup(); TLS_OUT_OF_INDEXES means "not initialized", and every
 * public entry point checks for it before touching the slot.
 */
static DWORD tlsIndex = TLS_OUT_OF_INDEXES;


/**
 * Static mutex to protect the access to g_pendingTlsData global and
 * stw_tls_data::next member.
 *
 * Initialized field by field (DebugInfo = -1, LockCount = -1, remaining
 * fields zero) instead of calling InitializeCriticalSection(), so it is
 * usable before any init function runs.
 * NOTE(review): this relies on the CRITICAL_SECTION layout/semantics of the
 * Windows headers (-1 LockCount meaning "unlocked") — confirm for new SDKs.
 */
static CRITICAL_SECTION g_mutex = {
   (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0
};

/**
 * There is no way to invoke TlsSetValue for a different thread, so we
 * temporarily put the thread data for non-current threads here.
 *
 * Singly linked through stw_tls_data::next. stw_tls_init() pushes entries;
 * stw_tls_lookup_pending_data() unlinks a thread's entry when that thread
 * first needs it, and stw_tls_cleanup() destroys whatever remains.
 */
static struct stw_tls_data *g_pendingTlsData = NULL;


/* Forward declarations: both helpers are used before their definitions. */
static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId);

static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId);
58
59
60boolean
61stw_tls_init(void)
62{
63   tlsIndex = TlsAlloc();
64   if (tlsIndex == TLS_OUT_OF_INDEXES) {
65      return FALSE;
66   }
67
68   /*
69    * DllMain is called with DLL_THREAD_ATTACH only for threads created after
70    * the DLL is loaded by the process.  So enumerate and add our hook to all
71    * previously existing threads.
72    *
73    * XXX: Except for the current thread since it there is an explicit
74    * stw_tls_init_thread() call for it later on.
75    */
76   if (1) {
77      DWORD dwCurrentProcessId = GetCurrentProcessId();
78      DWORD dwCurrentThreadId = GetCurrentThreadId();
79      HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId);
80      if (hSnapshot != INVALID_HANDLE_VALUE) {
81         THREADENTRY32 te;
82         te.dwSize = sizeof te;
83         if (Thread32First(hSnapshot, &te)) {
84            do {
85               if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
86                                sizeof te.th32OwnerProcessID) {
87                  if (te.th32OwnerProcessID == dwCurrentProcessId) {
88                     if (te.th32ThreadID != dwCurrentThreadId) {
89                        struct stw_tls_data *data;
90                        data = stw_tls_data_create(te.th32ThreadID);
91                        if (data) {
92                           EnterCriticalSection(&g_mutex);
93                           data->next = g_pendingTlsData;
94                           g_pendingTlsData = data;
95                           LeaveCriticalSection(&g_mutex);
96                        }
97                     }
98                  }
99               }
100               te.dwSize = sizeof te;
101            } while (Thread32Next(hSnapshot, &te));
102         }
103         CloseHandle(hSnapshot);
104      }
105   }
106
107   return TRUE;
108}
109
110
111/**
112 * Install windows hook for a given thread (not necessarily the current one).
113 */
114static struct stw_tls_data *
115stw_tls_data_create(DWORD dwThreadId)
116{
117   struct stw_tls_data *data;
118
119   if (0) {
120      debug_printf("%s(0x%04lx)\n", __FUNCTION__, dwThreadId);
121   }
122
123   data = calloc(1, sizeof *data);
124   if (!data) {
125      goto no_data;
126   }
127
128   data->dwThreadId = dwThreadId;
129
130   data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC,
131                                             stw_call_window_proc,
132                                             NULL,
133                                             dwThreadId);
134   if (data->hCallWndProcHook == NULL) {
135      goto no_hook;
136   }
137
138   return data;
139
140no_hook:
141   free(data);
142no_data:
143   return NULL;
144}
145
146/**
147 * Destroy the per-thread data/hook.
148 *
149 * It is important to remove all hooks when unloading our DLL, otherwise our
150 * hook function might be called after it is no longer there.
151 */
152static void
153stw_tls_data_destroy(struct stw_tls_data *data)
154{
155   assert(data);
156   if (!data) {
157      return;
158   }
159
160   if (0) {
161      debug_printf("%s(0x%04lx)\n", __FUNCTION__, data->dwThreadId);
162   }
163
164   if (data->hCallWndProcHook) {
165      UnhookWindowsHookEx(data->hCallWndProcHook);
166      data->hCallWndProcHook = NULL;
167   }
168
169   free(data);
170}
171
172boolean
173stw_tls_init_thread(void)
174{
175   struct stw_tls_data *data;
176
177   if (tlsIndex == TLS_OUT_OF_INDEXES) {
178      return FALSE;
179   }
180
181   data = stw_tls_data_create(GetCurrentThreadId());
182   if (!data) {
183      return FALSE;
184   }
185
186   TlsSetValue(tlsIndex, data);
187
188   return TRUE;
189}
190
191void
192stw_tls_cleanup_thread(void)
193{
194   struct stw_tls_data *data;
195
196   if (tlsIndex == TLS_OUT_OF_INDEXES) {
197      return;
198   }
199
200   data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
201   if (data) {
202      TlsSetValue(tlsIndex, NULL);
203   } else {
204      /* See if there this thread's data in on the pending list */
205      data = stw_tls_lookup_pending_data(GetCurrentThreadId());
206   }
207
208   if (data) {
209      stw_tls_data_destroy(data);
210   }
211}
212
213void
214stw_tls_cleanup(void)
215{
216   if (tlsIndex != TLS_OUT_OF_INDEXES) {
217      /*
218       * Destroy all items in g_pendingTlsData linked list.
219       */
220      EnterCriticalSection(&g_mutex);
221      while (g_pendingTlsData) {
222         struct stw_tls_data * data = g_pendingTlsData;
223         g_pendingTlsData = data->next;
224         stw_tls_data_destroy(data);
225      }
226      LeaveCriticalSection(&g_mutex);
227
228      TlsFree(tlsIndex);
229      tlsIndex = TLS_OUT_OF_INDEXES;
230   }
231}
232
233/*
234 * Search for the current thread in the g_pendingTlsData linked list.
235 *
236 * It will remove and return the node on success, or return NULL on failure.
237 */
238static struct stw_tls_data *
239stw_tls_lookup_pending_data(DWORD dwThreadId)
240{
241   struct stw_tls_data ** p_data;
242   struct stw_tls_data *data = NULL;
243
244   EnterCriticalSection(&g_mutex);
245   for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) {
246      if ((*p_data)->dwThreadId == dwThreadId) {
247         data = *p_data;
248
249	 /*
250	  * Unlink the node.
251	  */
252         *p_data = data->next;
253         data->next = NULL;
254
255	 break;
256      }
257   }
258   LeaveCriticalSection(&g_mutex);
259
260   return data;
261}
262
263struct stw_tls_data *
264stw_tls_get_data(void)
265{
266   struct stw_tls_data *data;
267
268   if (tlsIndex == TLS_OUT_OF_INDEXES) {
269      return NULL;
270   }
271
272   data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
273   if (!data) {
274      DWORD dwCurrentThreadId = GetCurrentThreadId();
275
276      /*
277       * Search for the current thread in the g_pendingTlsData linked list.
278       */
279      data = stw_tls_lookup_pending_data(dwCurrentThreadId);
280
281      if (!data) {
282         /*
283          * This should be impossible now.
284          */
285	 assert(!"Failed to find thread data for thread id");
286
287         /*
288          * DllMain is called with DLL_THREAD_ATTACH only by threads created
289          * after the DLL is loaded by the process
290          */
291         data = stw_tls_data_create(dwCurrentThreadId);
292         if (!data) {
293            return NULL;
294         }
295      }
296
297      TlsSetValue(tlsIndex, data);
298   }
299
300   assert(data);
301   assert(data->dwThreadId = GetCurrentThreadId());
302   assert(data->next == NULL);
303
304   return data;
305}
306