1/*------------------------------------------------------------------------
2 * OpenGL Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017-2019 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 * Copyright (c) 2019 NVIDIA Corporation.
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 *      http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 */ /*!
22 * \file
23 * \brief Subgroups Tests
24 */ /*--------------------------------------------------------------------*/
25
26#include "glcSubgroupsQuadTests.hpp"
27#include "glcSubgroupsTestsUtils.hpp"
28
29#include <string>
30#include <vector>
31
32using namespace tcu;
33using namespace std;
34
35namespace glc
36{
37namespace subgroups
38{
39namespace
40{
// Quad operations exercised by this file.
//
// The enumerators double as indices into the per-operation tables built in
// the program generators, so they must stay contiguous and zero-based.
// OPTYPE_LAST is the number of operations, not an operation itself.
enum OpType
{
	OPTYPE_QUAD_BROADCAST		= 0,	// subgroupQuadBroadcast
	OPTYPE_QUAD_SWAP_HORIZONTAL	= 1,	// subgroupQuadSwapHorizontal
	OPTYPE_QUAD_SWAP_VERTICAL	= 2,	// subgroupQuadSwapVertical
	OPTYPE_QUAD_SWAP_DIAGONAL	= 3,	// subgroupQuadSwapDiagonal
	OPTYPE_LAST					= 4
};
49
50static bool checkVertexPipelineStages(std::vector<const void*> datas,
51									  deUint32 width, deUint32)
52{
53	return glc::subgroups::check(datas, width, 1);
54}
55
56static bool checkComputeStage(std::vector<const void*> datas,
57						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
58						 deUint32)
59{
60	return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
61}
62
63std::string getOpTypeName(int opType)
64{
65	switch (opType)
66	{
67		default:
68			DE_FATAL("Unsupported op type");
69			return "";
70		case OPTYPE_QUAD_BROADCAST:
71			return "subgroupQuadBroadcast";
72		case OPTYPE_QUAD_SWAP_HORIZONTAL:
73			return "subgroupQuadSwapHorizontal";
74		case OPTYPE_QUAD_SWAP_VERTICAL:
75			return "subgroupQuadSwapVertical";
76		case OPTYPE_QUAD_SWAP_DIAGONAL:
77			return "subgroupQuadSwapDiagonal";
78	}
79}
80
81struct CaseDefinition
82{
83	int					opType;
84	ShaderStageFlags	shaderStage;
85	Format				format;
86	int					direction;
87};
88
89void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
90{
91	std::string			swapTable[OPTYPE_LAST];
92
93	subgroups::setFragmentShaderFrameBuffer(programCollection);
94
95	if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
96		subgroups::setVertexShaderFrameBuffer(programCollection);
97
98	swapTable[OPTYPE_QUAD_BROADCAST] = "";
99	swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = uint[](1u, 0u, 3u, 2u);\n";
100	swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = uint[](2u, 3u, 0u, 1u);\n";
101	swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = uint[](3u, 2u, 1u, 0u);\n";
102
103	if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
104	{
105		std::ostringstream	vertexSrc;
106		vertexSrc << "${VERSION_DECL}\n"
107			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
108			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
109			<< "layout(location = 0) in highp vec4 in_position;\n"
110			<< "layout(location = 0) out float result;\n"
111			<< "layout(binding = 0, std140) uniform Buffer0\n"
112			<< "{\n"
113			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
114			<< "};\n"
115			<< "\n"
116			<< "void main (void)\n"
117			<< "{\n"
118			<< "  uvec4 mask = subgroupBallot(true);\n"
119			<< swapTable[caseDef.opType];
120
121		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
122		{
123			vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
124				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
125				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
126		}
127		else
128		{
129			vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
130				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
131				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
132		}
133
134		vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
135			<< "  {\n"
136			<< "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
137			<< "  }\n"
138			<< "  else\n"
139			<< "  {\n"
140			<< "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
141			<< "  }\n"
142			<< "  gl_Position = in_position;\n"
143			<< "  gl_PointSize = 1.0f;\n"
144			<< "}\n";
145		programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
146	}
147	else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
148	{
149		std::ostringstream geometry;
150
151		geometry << "${VERSION_DECL}\n"
152			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
153			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
154			<< "layout(points) in;\n"
155			<< "layout(points, max_vertices = 1) out;\n"
156			<< "layout(location = 0) out float out_color;\n"
157			<< "layout(binding = 0, std140) uniform Buffer0\n"
158			<< "{\n"
159			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160			<< "};\n"
161			<< "\n"
162			<< "void main (void)\n"
163			<< "{\n"
164			<< "  uvec4 mask = subgroupBallot(true);\n"
165			<< swapTable[caseDef.opType];
166
167		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
168		{
169			geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
171				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
172		}
173		else
174		{
175			geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
178		}
179
180		geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
181			<< "  {\n"
182			<< "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
183			<< "  }\n"
184			<< "  else\n"
185			<< "  {\n"
186			<< "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
187			<< "  }\n"
188			<< "  gl_Position = gl_in[0].gl_Position;\n"
189			<< "  EmitVertex();\n"
190			<< "  EndPrimitive();\n"
191			<< "}\n";
192
193		programCollection.add("geometry") << glu::GeometrySource(geometry.str());
194	}
195	else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
196	{
197		std::ostringstream controlSource;
198
199		controlSource << "${VERSION_DECL}\n"
200			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
201			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
202			<< "layout(vertices = 2) out;\n"
203			<< "layout(location = 0) out float out_color[];\n"
204			<< "layout(binding = 0, std140) uniform Buffer0\n"
205			<< "{\n"
206			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
207			<< "};\n"
208			<< "\n"
209			<< "void main (void)\n"
210			<< "{\n"
211			<< "  if (gl_InvocationID == 0)\n"
212			<<"  {\n"
213			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
214			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
215			<< "  }\n"
216			<< "  uvec4 mask = subgroupBallot(true);\n"
217			<< swapTable[caseDef.opType];
218
219		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
220		{
221			controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
222				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
223				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
224		}
225		else
226		{
227			controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
228				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
229				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
230		}
231
232		controlSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
233			<< "  {\n"
234			<< "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
235			<< "  }\n"
236			<< "  else\n"
237			<< "  {\n"
238			<< "    out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
239			<< "  }\n"
240			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
241			<< "}\n";
242
243		programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
244		subgroups::setTesEvalShaderFrameBuffer(programCollection);
245	}
246	else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
247	{
248		ostringstream evaluationSource;
249		evaluationSource << "${VERSION_DECL}\n"
250			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
251			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
252			<< "layout(isolines, equal_spacing, ccw ) in;\n"
253			<< "layout(location = 0) out float out_color;\n"
254			<< "layout(binding = 0, std140) uniform Buffer0\n"
255			<< "{\n"
256			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
257			<< "};\n"
258			<< "\n"
259			<< "void main (void)\n"
260			<< "{\n"
261			<< "  uvec4 mask = subgroupBallot(true);\n"
262			<< swapTable[caseDef.opType];
263
264		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
265		{
266			evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
267				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
268				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
269		}
270		else
271		{
272			evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
273				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
274				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
275		}
276
277		evaluationSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
278			<< "  {\n"
279			<< "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
280			<< "  }\n"
281			<< "  else\n"
282			<< "  {\n"
283			<< "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
284			<< "  }\n"
285			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
286			<< "}\n";
287
288		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
289		programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
290	}
291	else
292	{
293		DE_FATAL("Unsupported shader stage");
294	}
295}
296
297void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
298{
299	std::string swapTable[OPTYPE_LAST];
300	swapTable[OPTYPE_QUAD_BROADCAST] = "";
301	swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = uint[](1u, 0u, 3u, 2u);\n";
302	swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = uint[](2u, 3u, 0u, 1u);\n";
303	swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = uint[](3u, 2u, 1u, 0u);\n";
304
305	if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
306	{
307		std::ostringstream src;
308
309		src << "${VERSION_DECL}\n"
310			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
311			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
312			<< "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
313			<< "layout(binding = 0, std430) buffer Buffer0\n"
314			<< "{\n"
315			<< "  uint result[];\n"
316			<< "};\n"
317			<< "layout(binding = 1, std430) buffer Buffer1\n"
318			<< "{\n"
319			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
320			<< "};\n"
321			<< "\n"
322			<< "void main (void)\n"
323			<< "{\n"
324			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
325			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
326			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
327			"gl_GlobalInvocationID.x;\n"
328			<< "  uvec4 mask = subgroupBallot(true);\n"
329			<< swapTable[caseDef.opType];
330
331
332		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
333		{
334			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
335				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
336				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
337		}
338		else
339		{
340			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
341				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
342				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
343		}
344
345		src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
346			<< "  {\n"
347			<< "    result[offset] = (op == data[otherID]) ? 1u : 0u;\n"
348			<< "  }\n"
349			<< "  else\n"
350			<< "  {\n"
351			<< "    result[offset] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
352			<< "  }\n"
353			<< "}\n";
354
355		programCollection.add("comp") << glu::ComputeSource(src.str());
356	}
357	else
358	{
359		std::ostringstream src;
360		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
361		{
362			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
363				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
364				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
365		}
366		else
367		{
368			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
369				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
370				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
371		}
372		const string sourceType = src.str();
373
374		{
375			const string vertex =
376				"${VERSION_DECL}\n"
377				"#extension GL_KHR_shader_subgroup_quad: enable\n"
378				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
379				"layout(binding = 0, std430) buffer Buffer0\n"
380				"{\n"
381				"  uint result[];\n"
382				"} b0;\n"
383				"layout(binding = 4, std430) readonly buffer Buffer4\n"
384				"{\n"
385				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
386				"};\n"
387				"\n"
388				"void main (void)\n"
389				"{\n"
390				"  uvec4 mask = subgroupBallot(true);\n"
391				+ swapTable[caseDef.opType]
392				+ sourceType +
393				"  if (subgroupBallotBitExtract(mask, otherID))\n"
394				"  {\n"
395				"    b0.result[gl_VertexID] = (op == data[otherID]) ? 1u : 0u;\n"
396				"  }\n"
397				"  else\n"
398				"  {\n"
399				"    b0.result[gl_VertexID] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
400				"  }\n"
401				"  float pixelSize = 2.0f/1024.0f;\n"
402				"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
403				"  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
404				"}\n";
405			programCollection.add("vert") << glu::VertexSource(vertex);
406		}
407
408		{
409			const string tesc =
410				"${VERSION_DECL}\n"
411				"#extension GL_KHR_shader_subgroup_quad: enable\n"
412				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
413				"layout(vertices=1) out;\n"
414				"layout(binding = 1, std430) buffer Buffer1\n"
415				"{\n"
416				"  uint result[];\n"
417				"} b1;\n"
418				"layout(binding = 4, std430) readonly buffer Buffer4\n"
419				"{\n"
420				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
421				"};\n"
422				"\n"
423				"void main (void)\n"
424				"{\n"
425				"  uvec4 mask = subgroupBallot(true);\n"
426				+ swapTable[caseDef.opType]
427				+ sourceType +
428				"  if (subgroupBallotBitExtract(mask, otherID))\n"
429				"  {\n"
430				"    b1.result[gl_PrimitiveID] = (op == data[otherID]) ? 1u : 0u;\n"
431				"  }\n"
432				"  else\n"
433				"  {\n"
434				"    b1.result[gl_PrimitiveID] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
435				"  }\n"
436				"  if (gl_InvocationID == 0)\n"
437				"  {\n"
438				"    gl_TessLevelOuter[0] = 1.0f;\n"
439				"    gl_TessLevelOuter[1] = 1.0f;\n"
440				"  }\n"
441				"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
442				"}\n";
443			programCollection.add("tesc") << glu::TessellationControlSource(tesc);
444		}
445
446		{
447			const string tese =
448				"${VERSION_DECL}\n"
449				"#extension GL_KHR_shader_subgroup_quad: enable\n"
450				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
451				"layout(isolines) in;\n"
452				"layout(binding = 2, std430)  buffer Buffer2\n"
453				"{\n"
454				"  uint result[];\n"
455				"} b2;\n"
456				"layout(binding = 4, std430) readonly buffer Buffer4\n"
457				"{\n"
458				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
459				"};\n"
460				"\n"
461				"void main (void)\n"
462				"{\n"
463				"  uvec4 mask = subgroupBallot(true);\n"
464				+ swapTable[caseDef.opType]
465				+ sourceType +
466				"  if (subgroupBallotBitExtract(mask, otherID))\n"
467				"  {\n"
468				"    b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1u : 0u;\n"
469				"  }\n"
470				"  else\n"
471				"  {\n"
472				"    b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
473				"  }\n"
474				"  float pixelSize = 2.0f/1024.0f;\n"
475				"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
476				"}\n";
477			programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
478		}
479
480		{
481			const string geometry =
482				// version added by addGeometryShadersFromTemplate
483				"#extension GL_KHR_shader_subgroup_quad: enable\n"
484				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
485				"layout(${TOPOLOGY}) in;\n"
486				"layout(points, max_vertices = 1) out;\n"
487				"layout(binding = 3, std430) buffer Buffer3\n"
488				"{\n"
489				"  uint result[];\n"
490				"} b3;\n"
491				"layout(binding = 4, std430) readonly buffer Buffer4\n"
492				"{\n"
493				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
494				"};\n"
495				"\n"
496				"void main (void)\n"
497				"{\n"
498				"  uvec4 mask = subgroupBallot(true);\n"
499				+ swapTable[caseDef.opType]
500				+ sourceType +
501				"  if (subgroupBallotBitExtract(mask, otherID))\n"
502				"  {\n"
503				"    b3.result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1u : 0u;\n"
504				"  }\n"
505				"  else\n"
506				"  {\n"
507				"    b3.result[gl_PrimitiveIDIn] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
508				"  }\n"
509				"  gl_Position = gl_in[0].gl_Position;\n"
510				"  EmitVertex();\n"
511				"  EndPrimitive();\n"
512				"}\n";
513			subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
514		}
515
516		{
517			const string fragment =
518				"${VERSION_DECL}\n"
519				"#extension GL_KHR_shader_subgroup_quad: enable\n"
520				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
521				"precision highp int;\n"
522				"precision highp float;\n"
523				"layout(location = 0) out uint result;\n"
524				"layout(binding = 4, std430) readonly buffer Buffer4\n"
525				"{\n"
526				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
527				"};\n"
528				"void main (void)\n"
529				"{\n"
530				"  uvec4 mask = subgroupBallot(true);\n"
531				+ swapTable[caseDef.opType]
532				+ sourceType +
533				"  if (subgroupBallotBitExtract(mask, otherID))\n"
534				"  {\n"
535				"    result = (op == data[otherID]) ? 1u : 0u;\n"
536				"  }\n"
537				"  else\n"
538				"  {\n"
539				"    result = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
540				"  }\n"
541				"}\n";
542			programCollection.add("fragment") << glu::FragmentSource(fragment);
543		}
544		subgroups::addNoSubgroupShader(programCollection);
545	}
546}
547
548void supportedCheck (Context& context, CaseDefinition caseDef)
549{
550	if (!subgroups::isSubgroupSupported(context))
551		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
552
553	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_QUAD_BIT))
554		TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
555
556
557	if (subgroups::isDoubleFormat(caseDef.format) &&
558			!subgroups::isDoubleSupportedForDevice(context))
559	{
560		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
561	}
562}
563
564tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
565{
566	if (!subgroups::areSubgroupOperationsSupportedForStage(
567				context, caseDef.shaderStage))
568	{
569		if (subgroups::areSubgroupOperationsRequiredForStage(
570					caseDef.shaderStage))
571		{
572			return tcu::TestStatus::fail(
573					   "Shader stage " +
574					   subgroups::getShaderStageName(caseDef.shaderStage) +
575					   " is required to support subgroup operations!");
576		}
577		else
578		{
579			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
580		}
581	}
582
583	subgroups::SSBOData inputData;
584	inputData.format = caseDef.format;
585	inputData.layout = subgroups::SSBOData::LayoutStd140;
586	inputData.numElements = subgroups::maxSupportedSubgroupSize();
587	inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
588	inputData.binding = 0u;
589
590	if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
591		return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
592	else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
593		return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
594	else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
595		return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
596	else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
597		return subgroups::makeTessellationEvaluationFrameBufferTest(context,  FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
598	else
599		TCU_THROW(InternalError, "Unhandled shader stage");
600}
601
602
603tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
604{
605	if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
606	{
607		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
608		{
609			return tcu::TestStatus::fail(
610					   "Shader stage " +
611					   subgroups::getShaderStageName(caseDef.shaderStage) +
612					   " is required to support subgroup operations!");
613		}
614		subgroups::SSBOData inputData;
615		inputData.format = caseDef.format;
616		inputData.layout = subgroups::SSBOData::LayoutStd430;
617		inputData.numElements = subgroups::maxSupportedSubgroupSize();
618		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
619		inputData.binding = 1u;
620
621		return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
622	}
623	else
624	{
625		int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
626
627		ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
628
629		if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
630		{
631			if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
632				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
633			else
634				stages = SHADER_STAGE_FRAGMENT_BIT;
635		}
636
637		if ((ShaderStageFlags)0u == stages)
638			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
639
640		subgroups::SSBOData inputData;
641		inputData.format			= caseDef.format;
642		inputData.layout			= subgroups::SSBOData::LayoutStd430;
643		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
644		inputData.initializeType	= subgroups::SSBOData::InitializeNonZero;
645		inputData.binding			= 4u;
646		inputData.stages			= stages;
647
648		return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
649	}
650}
651}
652
653deqp::TestCaseGroup* createSubgroupsQuadTests(deqp::Context& testCtx)
654{
655	de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
656		testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
657	de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
658		testCtx, "compute", "Subgroup arithmetic category tests: compute"));
659	de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
660		testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
661
662	const Format formats[] =
663	{
664		FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT,
665		FORMAT_R32G32B32A32_SINT, FORMAT_R32_UINT, FORMAT_R32G32_UINT,
666		FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
667		FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT,
668		FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
669		FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT,
670		FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
671		FORMAT_R32_BOOL, FORMAT_R32G32_BOOL,
672		FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
673	};
674
675	const ShaderStageFlags stages[] =
676	{
677		SHADER_STAGE_VERTEX_BIT,
678		SHADER_STAGE_TESS_EVALUATION_BIT,
679		SHADER_STAGE_TESS_CONTROL_BIT,
680		SHADER_STAGE_GEOMETRY_BIT,
681	};
682
683	for (int direction = 0; direction < 4; ++direction)
684	{
685		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
686		{
687			const Format format = formats[formatIndex];
688
689			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
690			{
691				const std::string op = de::toLower(getOpTypeName(opTypeIndex));
692				std::ostringstream name;
693				name << de::toLower(op);
694
695				if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
696				{
697					name << "_" << direction;
698				}
699				else
700				{
701					if (0 != direction)
702					{
703						// We don't need direction for swap operations.
704						continue;
705					}
706				}
707
708				name << "_" << subgroups::getFormatNameForGLSL(format);
709
710				{
711					const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format, direction};
712					SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
713				}
714
715				{
716					const CaseDefinition caseDef =
717					{
718						opTypeIndex,
719						SHADER_STAGE_ALL_GRAPHICS,
720						format,
721						direction
722					};
723					SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
724				}
725				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
726				{
727					const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
728					SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
729												supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
730				}
731
732			}
733		}
734	}
735
736	de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
737		testCtx, "quad", "Subgroup quad category tests"));
738
739	group->addChild(graphicGroup.release());
740	group->addChild(computeGroup.release());
741	group->addChild(framebufferGroup.release());
742
743	return group.release();
744}
745} // subgroups
746} // glc
747