1 // Copyright 2019 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tests/DawnTest.h"
16
17 #include "utils/WGPUHelpers.h"
18
19 #include <array>
20
21 class ComputeSharedMemoryTests : public DawnTest {
22 public:
23 static constexpr uint32_t kInstances = 11;
24
25 void BasicTest(const char* shader);
26 };
27
BasicTest(const char* shader)28 void ComputeSharedMemoryTests::BasicTest(const char* shader) {
29 // Set up shader and pipeline
30 auto module = utils::CreateShaderModule(device, shader);
31
32 wgpu::ComputePipelineDescriptor csDesc;
33 csDesc.compute.module = module;
34 csDesc.compute.entryPoint = "main";
35 wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
36
37 // Set up dst storage buffer
38 wgpu::BufferDescriptor dstDesc;
39 dstDesc.size = sizeof(uint32_t);
40 dstDesc.usage =
41 wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
42 wgpu::Buffer dst = device.CreateBuffer(&dstDesc);
43
44 const uint32_t zero = 0;
45 queue.WriteBuffer(dst, 0, &zero, sizeof(zero));
46
47 // Set up bind group and issue dispatch
48 wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
49 {
50 {0, dst, 0, sizeof(uint32_t)},
51 });
52
53 wgpu::CommandBuffer commands;
54 {
55 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
56 wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
57 pass.SetPipeline(pipeline);
58 pass.SetBindGroup(0, bindGroup);
59 pass.Dispatch(1);
60 pass.EndPass();
61
62 commands = encoder.Finish();
63 }
64
65 queue.Submit(1, &commands);
66
67 const uint32_t expected = kInstances;
68 EXPECT_BUFFER_U32_RANGE_EQ(&expected, dst, 0, 1);
69 }
70
71 // Basic shared memory test
TEST_P(ComputeSharedMemoryTests, Basic) {
    // One 4x4 workgroup shares the counter `tmp`. Invocation 0 clears it, then
    // on iteration i of the loop only the invocation whose flattened index
    // equals i increments it, with a workgroupBarrier() after each iteration.
    // After kInstances (11) iterations, invocation 0 writes tmp to dst.x.
    const char* sharedCounterShader = R"(
        let kTileSize : u32 = 4u;
        let kInstances : u32 = 11u;

        [[block]] struct Dst {
            x : u32;
        };

        [[group(0), binding(0)]] var<storage, write> dst : Dst;
        var<workgroup> tmp : u32;

        [[stage(compute), workgroup_size(4,4,1)]]
        fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
            let index : u32 = LocalInvocationID.y * kTileSize + LocalInvocationID.x;
            if (index == 0u) {
                tmp = 0u;
            }
            workgroupBarrier();
            for (var i : u32 = 0u; i < kInstances; i = i + 1u) {
                if (i == index) {
                    tmp = tmp + 1u;
                }
                workgroupBarrier();
            }
            if (index == 0u) {
                dst.x = tmp;
            }
        })";

    BasicTest(sharedCounterShader);
}
102
103 // Test using assorted types in workgroup memory. MSL lacks constructors
104 // for matrices in threadgroup memory. Basic test that reading and
105 // writing a matrix in workgroup memory works.
TEST_P(ComputeSharedMemoryTests, AssortedTypes) {
    // Each of the four invocations fills one workgroup variable (struct, matrix,
    // array, vector) with the values 4*x .. 4*x+3, where x is its local id.
    // After the barrier, invocation 0 copies all four variables to the storage
    // buffer.
    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
        struct StructValues {
            m: mat2x2<f32>;
        };

        [[block]] struct Dst {
            d_struct : StructValues;
            d_matrix : mat2x2<f32>;
            d_array : array<u32, 4>;
            d_vector : vec4<f32>;
        };

        [[group(0), binding(0)]] var<storage, write> dst : Dst;

        var<workgroup> wg_struct : StructValues;
        var<workgroup> wg_matrix : mat2x2<f32>;
        var<workgroup> wg_array : array<u32, 4>;
        var<workgroup> wg_vector : vec4<f32>;

        [[stage(compute), workgroup_size(4,1,1)]]
        fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {

            let i = 4u * LocalInvocationID.x;
            if (LocalInvocationID.x == 0u) {
                wg_struct.m = mat2x2<f32>(
                    vec2<f32>(f32(i), f32(i + 1u)),
                    vec2<f32>(f32(i + 2u), f32(i + 3u)));
            } elseif (LocalInvocationID.x == 1u) {
                wg_matrix = mat2x2<f32>(
                    vec2<f32>(f32(i), f32(i + 1u)),
                    vec2<f32>(f32(i + 2u), f32(i + 3u)));
            } elseif (LocalInvocationID.x == 2u) {
                wg_array[0u] = i;
                wg_array[1u] = i + 1u;
                wg_array[2u] = i + 2u;
                wg_array[3u] = i + 3u;
            } elseif (LocalInvocationID.x == 3u) {
                wg_vector = vec4<f32>(
                    f32(i), f32(i + 1u), f32(i + 2u), f32(i + 3u));
            }

            workgroupBarrier();

            if (LocalInvocationID.x == 0u) {
                dst.d_struct = wg_struct;
                dst.d_matrix = wg_matrix;
                dst.d_array = wg_array;
                dst.d_vector = wg_vector;
            }
        }
    )");

    wgpu::ComputePipelineDescriptor pipelineDesc;
    pipelineDesc.compute.module = module;
    pipelineDesc.compute.entryPoint = "main";
    wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&pipelineDesc);

    // Destination storage buffer: 64 bytes, i.e. four members of four 4-byte
    // elements each.
    wgpu::BufferDescriptor bufferDesc;
    bufferDesc.size = 64;
    bufferDesc.usage =
        wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
    wgpu::Buffer dst = device.CreateBuffer(&bufferDesc);

    // Bind the buffer at binding 0 of the pipeline's first bind group layout.
    wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
                                                     {
                                                         {0, dst},
                                                     });

    // Record a compute pass with a single Dispatch(1). The encoder objects are
    // local to the lambda, so they are released before the submit below.
    wgpu::CommandBuffer commands = [&]() {
        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
        pass.SetPipeline(pipeline);
        pass.SetBindGroup(0, bindGroup);
        pass.Dispatch(1);
        pass.EndPass();
        return encoder.Finish();
    }();

    queue.Submit(1, &commands);

    // Expected contents of each member, in declaration order of `Dst`.
    const std::array<float, 4> structValues = {0.0f, 1.0f, 2.0f, 3.0f};
    const std::array<float, 4> matrixValues = {4.0f, 5.0f, 6.0f, 7.0f};
    const std::array<uint32_t, 4> arrayValues = {8u, 9u, 10u, 11u};
    const std::array<float, 4> vectorValues = {12.0f, 13.0f, 14.0f, 15.0f};

    // Offsets at which each member is checked; consecutive members are 16 apart.
    constexpr uint64_t kStructOffset = 0;
    constexpr uint64_t kMatrixOffset = 16;
    constexpr uint64_t kArrayOffset = 32;
    constexpr uint64_t kVectorOffset = 48;

    EXPECT_BUFFER_FLOAT_RANGE_EQ(structValues.data(), dst, kStructOffset, 4);
    EXPECT_BUFFER_FLOAT_RANGE_EQ(matrixValues.data(), dst, kMatrixOffset, 4);
    EXPECT_BUFFER_U32_RANGE_EQ(arrayValues.data(), dst, kArrayOffset, 4);
    EXPECT_BUFFER_FLOAT_RANGE_EQ(vectorValues.data(), dst, kVectorOffset, 4);
}
198
199 DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests,
200 D3D12Backend(),
201 MetalBackend(),
202 OpenGLBackend(),
203 OpenGLESBackend(),
204 VulkanBackend());
205