1 // Copyright 2020 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn/dawn_thread_dispatch_proc.h"
16 #include "dawn/webgpu_cpp.h"
17 #include "dawn_native/DawnNative.h"
18 #include "dawn_native/Instance.h"
19 #include "dawn_native/null/DeviceNull.h"
20 
21 #include <gtest/gtest.h>
22 #include <atomic>
23 #include <thread>
24 
25 class PerThreadProcTests : public testing::Test {
26   public:
PerThreadProcTests()27     PerThreadProcTests()
28         : mNativeInstance(dawn_native::InstanceBase::Create()),
29           mNativeAdapter(mNativeInstance.Get()) {
30     }
31     ~PerThreadProcTests() override = default;
32 
33   protected:
34     Ref<dawn_native::InstanceBase> mNativeInstance;
35     dawn_native::null::Adapter mNativeAdapter;
36 };
37 
38 // Test that procs can be set per thread. This test overrides deviceCreateBuffer with a dummy proc
39 // for each thread that increments a counter. Because each thread has their own proc and counter,
40 // there should be no data races. The per-thread procs also check that the current thread id is
41 // exactly equal to the expected thread id.
TEST_F(PerThreadProcTests, DispatchesPerThread)42 TEST_F(PerThreadProcTests, DispatchesPerThread) {
43     dawnProcSetProcs(&dawnThreadDispatchProcTable);
44 
45     // Threads will block on this atomic to be sure we set procs on both threads before
46     // either thread calls the procs.
47     std::atomic<bool> ready(false);
48 
49     static int threadACounter = 0;
50     static int threadBCounter = 0;
51 
52     static std::atomic<std::thread::id> threadIdA;
53     static std::atomic<std::thread::id> threadIdB;
54 
55     constexpr int kThreadATargetCount = 28347;
56     constexpr int kThreadBTargetCount = 40420;
57 
58     // Note: Acquire doesn't call reference or release.
59     wgpu::Device deviceA =
60         wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
61 
62     wgpu::Device deviceB =
63         wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
64 
65     std::thread threadA([&]() {
66         DawnProcTable procs = dawn_native::GetProcs();
67         procs.deviceCreateBuffer = [](WGPUDevice device,
68                                       WGPUBufferDescriptor const* descriptor) -> WGPUBuffer {
69             EXPECT_EQ(std::this_thread::get_id(), threadIdA);
70             threadACounter++;
71             return nullptr;
72         };
73         dawnProcSetPerThreadProcs(&procs);
74 
75         while (!ready) {
76         }  // Should be fast, so just spin.
77 
78         for (int i = 0; i < kThreadATargetCount; ++i) {
79             deviceA.CreateBuffer(nullptr);
80         }
81 
82         deviceA = nullptr;
83         dawnProcSetPerThreadProcs(nullptr);
84     });
85 
86     std::thread threadB([&]() {
87         DawnProcTable procs = dawn_native::GetProcs();
88         procs.deviceCreateBuffer = [](WGPUDevice device,
89                                       WGPUBufferDescriptor const* bufferDesc) -> WGPUBuffer {
90             EXPECT_EQ(std::this_thread::get_id(), threadIdB);
91             threadBCounter++;
92             return nullptr;
93         };
94         dawnProcSetPerThreadProcs(&procs);
95 
96         while (!ready) {
97         }  // Should be fast, so just spin.
98 
99         for (int i = 0; i < kThreadBTargetCount; ++i) {
100             deviceB.CreateBuffer(nullptr);
101         }
102 
103         deviceB = nullptr;
104         dawnProcSetPerThreadProcs(nullptr);
105     });
106 
107     threadIdA = threadA.get_id();
108     threadIdB = threadB.get_id();
109 
110     ready = true;
111     threadA.join();
112     threadB.join();
113 
114     EXPECT_EQ(threadACounter, kThreadATargetCount);
115     EXPECT_EQ(threadBCounter, kThreadBTargetCount);
116 
117     dawnProcSetProcs(nullptr);
118 }
119