1/*------------------------------------------------------------------------
2 * OpenGL Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017-2019 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 * Copyright (c) 2019 NVIDIA Corporation.
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 *      http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 */ /*!
22 * \file
23 * \brief Subgroups Tests
24 */ /*--------------------------------------------------------------------*/
25
26#include "glcSubgroupsClusteredTests.hpp"
27#include "glcSubgroupsTestsUtils.hpp"
28
29#include <string>
30#include <vector>
31
32using namespace tcu;
33using namespace std;
34
35namespace glc
36{
37namespace subgroups
38{
39namespace
40{
// Clustered subgroup operations covered by this file. Enumerators are numbered
// consecutively from zero, so an int loop from 0 up to (but excluding)
// OPTYPE_CLUSTERED_LAST visits every operation exactly once.
enum OpType
{
	OPTYPE_CLUSTERED_ADD,	// subgroupClusteredAdd
	OPTYPE_CLUSTERED_MUL,	// subgroupClusteredMul
	OPTYPE_CLUSTERED_MIN,	// subgroupClusteredMin
	OPTYPE_CLUSTERED_MAX,	// subgroupClusteredMax
	OPTYPE_CLUSTERED_AND,	// subgroupClusteredAnd
	OPTYPE_CLUSTERED_OR,	// subgroupClusteredOr
	OPTYPE_CLUSTERED_XOR,	// subgroupClusteredXor
	OPTYPE_CLUSTERED_LAST	// Sentinel: number of operations, not an operation itself.
};
52
53static bool checkVertexPipelineStages(std::vector<const void*> datas,
54									  deUint32 width, deUint32)
55{
56	return glc::subgroups::check(datas, width, 1);
57}
58
59static bool checkComputeStage(std::vector<const void*> datas,
60						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
61						 deUint32)
62{
63	return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
64}
65
66std::string getOpTypeName(int opType)
67{
68	switch (opType)
69	{
70		default:
71			DE_FATAL("Unsupported op type");
72			return "";
73		case OPTYPE_CLUSTERED_ADD:
74			return "subgroupClusteredAdd";
75		case OPTYPE_CLUSTERED_MUL:
76			return "subgroupClusteredMul";
77		case OPTYPE_CLUSTERED_MIN:
78			return "subgroupClusteredMin";
79		case OPTYPE_CLUSTERED_MAX:
80			return "subgroupClusteredMax";
81		case OPTYPE_CLUSTERED_AND:
82			return "subgroupClusteredAnd";
83		case OPTYPE_CLUSTERED_OR:
84			return "subgroupClusteredOr";
85		case OPTYPE_CLUSTERED_XOR:
86			return "subgroupClusteredXor";
87	}
88}
89
90std::string getOpTypeOperation(int opType, Format format, std::string lhs, std::string rhs)
91{
92	switch (opType)
93	{
94		default:
95			DE_FATAL("Unsupported op type");
96			return "";
97		case OPTYPE_CLUSTERED_ADD:
98			return lhs + " + " + rhs;
99		case OPTYPE_CLUSTERED_MUL:
100			return lhs + " * " + rhs;
101		case OPTYPE_CLUSTERED_MIN:
102			switch (format)
103			{
104				default:
105					return "min(" + lhs + ", " + rhs + ")";
106				case FORMAT_R32_SFLOAT:
107				case FORMAT_R64_SFLOAT:
108					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
109				case FORMAT_R32G32_SFLOAT:
110				case FORMAT_R32G32B32_SFLOAT:
111				case FORMAT_R32G32B32A32_SFLOAT:
112				case FORMAT_R64G64_SFLOAT:
113				case FORMAT_R64G64B64_SFLOAT:
114				case FORMAT_R64G64B64A64_SFLOAT:
115					return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
116			}
117		case OPTYPE_CLUSTERED_MAX:
118			switch (format)
119			{
120				default:
121					return "max(" + lhs + ", " + rhs + ")";
122				case FORMAT_R32_SFLOAT:
123				case FORMAT_R64_SFLOAT:
124					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
125				case FORMAT_R32G32_SFLOAT:
126				case FORMAT_R32G32B32_SFLOAT:
127				case FORMAT_R32G32B32A32_SFLOAT:
128				case FORMAT_R64G64_SFLOAT:
129				case FORMAT_R64G64B64_SFLOAT:
130				case FORMAT_R64G64B64A64_SFLOAT:
131					return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
132			}
133		case OPTYPE_CLUSTERED_AND:
134			switch (format)
135			{
136				default:
137					return lhs + " & " + rhs;
138				case FORMAT_R32_BOOL:
139					return lhs + " && " + rhs;
140				case FORMAT_R32G32_BOOL:
141					return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
142				case FORMAT_R32G32B32_BOOL:
143					return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
144				case FORMAT_R32G32B32A32_BOOL:
145					return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
146			}
147		case OPTYPE_CLUSTERED_OR:
148			switch (format)
149			{
150				default:
151					return lhs + " | " + rhs;
152				case FORMAT_R32_BOOL:
153					return lhs + " || " + rhs;
154				case FORMAT_R32G32_BOOL:
155					return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
156				case FORMAT_R32G32B32_BOOL:
157					return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
158				case FORMAT_R32G32B32A32_BOOL:
159					return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
160			}
161		case OPTYPE_CLUSTERED_XOR:
162			switch (format)
163			{
164				default:
165					return lhs + " ^ " + rhs;
166				case FORMAT_R32_BOOL:
167					return lhs + " ^^ " + rhs;
168				case FORMAT_R32G32_BOOL:
169					return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
170				case FORMAT_R32G32B32_BOOL:
171					return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
172				case FORMAT_R32G32B32A32_BOOL:
173					return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
174			}
175	}
176}
177
178std::string getIdentity(int opType, Format format)
179{
180	bool isFloat = false;
181	bool isInt = false;
182	bool isUnsigned = false;
183
184	switch (format)
185	{
186		default:
187			DE_FATAL("Unhandled format!");
188			break;
189		case FORMAT_R32_SINT:
190		case FORMAT_R32G32_SINT:
191		case FORMAT_R32G32B32_SINT:
192		case FORMAT_R32G32B32A32_SINT:
193			isInt = true;
194			break;
195		case FORMAT_R32_UINT:
196		case FORMAT_R32G32_UINT:
197		case FORMAT_R32G32B32_UINT:
198		case FORMAT_R32G32B32A32_UINT:
199			isUnsigned = true;
200			break;
201		case FORMAT_R32_SFLOAT:
202		case FORMAT_R32G32_SFLOAT:
203		case FORMAT_R32G32B32_SFLOAT:
204		case FORMAT_R32G32B32A32_SFLOAT:
205		case FORMAT_R64_SFLOAT:
206		case FORMAT_R64G64_SFLOAT:
207		case FORMAT_R64G64B64_SFLOAT:
208		case FORMAT_R64G64B64A64_SFLOAT:
209			isFloat = true;
210			break;
211		case FORMAT_R32_BOOL:
212		case FORMAT_R32G32_BOOL:
213		case FORMAT_R32G32B32_BOOL:
214		case FORMAT_R32G32B32A32_BOOL:
215			break; // bool types are not anything
216	}
217
218	switch (opType)
219	{
220		default:
221			DE_FATAL("Unsupported op type");
222			return "";
223		case OPTYPE_CLUSTERED_ADD:
224			return subgroups::getFormatNameForGLSL(format) + "(0)";
225		case OPTYPE_CLUSTERED_MUL:
226			return subgroups::getFormatNameForGLSL(format) + "(1)";
227		case OPTYPE_CLUSTERED_MIN:
228			if (isFloat)
229			{
230				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
231			}
232			else if (isInt)
233			{
234				return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
235			}
236			else if (isUnsigned)
237			{
238				return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
239			}
240			else
241			{
242				DE_FATAL("Unhandled case");
243				return "";
244			}
245		case OPTYPE_CLUSTERED_MAX:
246			if (isFloat)
247			{
248				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
249			}
250			else if (isInt)
251			{
252				return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
253			}
254			else if (isUnsigned)
255			{
256				return subgroups::getFormatNameForGLSL(format) + "(0u)";
257			}
258			else
259			{
260				DE_FATAL("Unhandled case");
261				return "";
262			}
263		case OPTYPE_CLUSTERED_AND:
264			return subgroups::getFormatNameForGLSL(format) + "(~0)";
265		case OPTYPE_CLUSTERED_OR:
266			return subgroups::getFormatNameForGLSL(format) + "(0)";
267		case OPTYPE_CLUSTERED_XOR:
268			return subgroups::getFormatNameForGLSL(format) + "(0)";
269	}
270}
271
272std::string getCompare(int opType, Format format, std::string lhs, std::string rhs)
273{
274	std::string formatName = subgroups::getFormatNameForGLSL(format);
275	switch (format)
276	{
277		default:
278			return "all(equal(" + lhs + ", " + rhs + "))";
279		case FORMAT_R32_BOOL:
280		case FORMAT_R32_UINT:
281		case FORMAT_R32_SINT:
282			return "(" + lhs + " == " + rhs + ")";
283		case FORMAT_R32_SFLOAT:
284		case FORMAT_R64_SFLOAT:
285			switch (opType)
286			{
287				default:
288					return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
289				case OPTYPE_CLUSTERED_MIN:
290				case OPTYPE_CLUSTERED_MAX:
291					return "(" + lhs + " == " + rhs + ")";
292			}
293		case FORMAT_R32G32_SFLOAT:
294		case FORMAT_R32G32B32_SFLOAT:
295		case FORMAT_R32G32B32A32_SFLOAT:
296		case FORMAT_R64G64_SFLOAT:
297		case FORMAT_R64G64B64_SFLOAT:
298		case FORMAT_R64G64B64A64_SFLOAT:
299			switch (opType)
300			{
301				default:
302					return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
303				case OPTYPE_CLUSTERED_MIN:
304				case OPTYPE_CLUSTERED_MAX:
305					return "all(equal(" + lhs + ", " + rhs + "))";
306			}
307	}
308}
309
310struct CaseDefinition
311{
312	int					opType;
313	ShaderStageFlags	shaderStage;
314	Format				format;
315};
316
317std::string getBodySource(CaseDefinition caseDef)
318{
319	std::ostringstream bdy;
320	bdy << "  bool tempResult = true;\n";
321
322	for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
323	{
324		bdy	<< "  {\n"
325			<< "    const uint clusterSize = " << i << "u;\n"
326			<< "    if (clusterSize <= gl_SubgroupSize)\n"
327			<< "    {\n"
328			<< "      " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
329			<< getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
330			<< "      for (uint clusterOffset = 0u; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
331			<< "      {\n"
332			<< "        " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
333			<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
334			<< "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
335			<< "        {\n"
336			<< "          if (subgroupBallotBitExtract(mask, index))\n"
337			<< "          {\n"
338			<< "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
339			<< "          }\n"
340			<< "        }\n"
341			<< "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
342			<< "        {\n"
343			<< "          if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
344			<< "          {\n"
345			<< "            tempResult = false;\n"
346			<< "          }\n"
347			<< "        }\n"
348			<< "      }\n"
349			<< "    }\n"
350			<< "  }\n";
351	}
352	return bdy.str();
353}
354
355void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
356{
357	subgroups::setFragmentShaderFrameBuffer(programCollection);
358
359	if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
360		subgroups::setVertexShaderFrameBuffer(programCollection);
361
362	std::string bdy = getBodySource(caseDef);
363
364	if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
365	{
366		std::ostringstream				vertexSrc;
367		vertexSrc << "${VERSION_DECL}\n"
368			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
369			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
370			<< "layout(location = 0) in highp vec4 in_position;\n"
371			<< "layout(location = 0) out float out_color;\n"
372			<< "layout(binding = 0, std140) uniform Buffer0\n"
373			<< "{\n"
374			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
375			<< "};\n"
376			<< "\n"
377			<< "void main (void)\n"
378			<< "{\n"
379			<< "  uvec4 mask = subgroupBallot(true);\n"
380			<< bdy
381			<< "  out_color = float(tempResult ? 1 : 0);\n"
382			<< "  gl_Position = in_position;\n"
383			<< "  gl_PointSize = 1.0f;\n"
384			<< "}\n";
385		programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
386	}
387	else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
388	{
389		std::ostringstream geometry;
390
391		geometry  << "${VERSION_DECL}\n"
392			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
393			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
394			<< "layout(points) in;\n"
395			<< "layout(points, max_vertices = 1) out;\n"
396			<< "layout(location = 0) out float out_color;\n"
397			<< "layout(binding = 0, std140) uniform Buffer0\n"
398			<< "{\n"
399			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
400			<< "};\n"
401			<< "\n"
402			<< "void main (void)\n"
403			<< "{\n"
404			<< "  uvec4 mask = subgroupBallot(true);\n"
405			<< bdy
406			<< "  out_color = tempResult ? 1.0 : 0.0;\n"
407			<< "  gl_Position = gl_in[0].gl_Position;\n"
408			<< "  EmitVertex();\n"
409			<< "  EndPrimitive();\n"
410			<< "}\n";
411
412		programCollection.add("geometry") << glu::GeometrySource(geometry.str());
413	}
414	else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
415	{
416		std::ostringstream controlSource;
417
418		controlSource << "${VERSION_DECL}\n"
419			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
420			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
421			<< "layout(vertices = 2) out;\n"
422			<< "layout(location = 0) out float out_color[];\n"
423			<< "layout(binding = 0, std140) uniform Buffer0\n"
424			<< "{\n"
425			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
426			<< "};\n"
427			<< "\n"
428			<< "void main (void)\n"
429			<< "{\n"
430			<< "  if (gl_InvocationID == 0)\n"
431			<<"  {\n"
432			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
433			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
434			<< "  }\n"
435			<< "  uvec4 mask = subgroupBallot(true);\n"
436			<< bdy
437			<< "  out_color[gl_InvocationID] = tempResult ? 1.0 : 0.0;\n"
438			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
439			<< "}\n";
440
441		programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
442		subgroups::setTesEvalShaderFrameBuffer(programCollection);
443	}
444	else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
445	{
446		std::ostringstream evaluationSource;
447
448		evaluationSource << "${VERSION_DECL}\n"
449			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
450			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
451			<< "layout(isolines, equal_spacing, ccw ) in;\n"
452			<< "layout(location = 0) out float out_color;\n"
453			<< "layout(binding = 0, std140) uniform Buffer0\n"
454			<< "{\n"
455			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
456			<< "};\n"
457			<< "\n"
458			<< "void main (void)\n"
459			<< "{\n"
460			<< "  uvec4 mask = subgroupBallot(true);\n"
461			<< bdy
462			<< "  out_color = tempResult ? 1.0 : 0.0;\n"
463			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
464			<< "}\n";
465
466		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
467		programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
468	}
469	else
470	{
471		DE_FATAL("Unsupported shader stage");
472	}
473}
474
475void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
476{
477	std::string bdy = getBodySource(caseDef);
478
479	if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
480	{
481		std::ostringstream src;
482
483		src << "${VERSION_DECL}\n"
484			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
485			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
486			<< "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
487			<< "layout(binding = 0, std430) buffer Buffer0\n"
488			<< "{\n"
489			<< "  uint result[];\n"
490			<< "};\n"
491			<< "layout(binding = 1, std430) buffer Buffer1\n"
492			<< "{\n"
493			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
494			<< "};\n"
495			<< "\n"
496			<< "void main (void)\n"
497			<< "{\n"
498			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
499			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
500			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
501			"gl_GlobalInvocationID.x;\n"
502			<< "  uvec4 mask = subgroupBallot(true);\n"
503			<< bdy
504			<< "  result[offset] = tempResult ? 1u : 0u;\n"
505			<< "}\n";
506
507		programCollection.add("comp") << glu::ComputeSource(src.str());
508	}
509	else
510	{
511		{
512			const string vertex =
513				"${VERSION_DECL}\n"
514				"#extension GL_KHR_shader_subgroup_clustered: enable\n"
515				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
516				"layout(binding = 0, std430) buffer Buffer0\n"
517				"{\n"
518				"  uint result[];\n"
519				"} b0;\n"
520				"layout(binding = 4, std430) readonly buffer Buffer4\n"
521				"{\n"
522				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
523				"};\n"
524				"\n"
525				"void main (void)\n"
526				"{\n"
527				"  uvec4 mask = subgroupBallot(true);\n"
528				+ bdy +
529				"  b0.result[gl_VertexID] = tempResult ? 1u : 0u;\n"
530				"  float pixelSize = 2.0f/1024.0f;\n"
531				"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
532				"  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
533				"}\n";
534
535			programCollection.add("vert") << glu::VertexSource(vertex);
536		}
537
538		{
539			const string tesc =
540			"${VERSION_DECL}\n"
541			"#extension GL_KHR_shader_subgroup_clustered: enable\n"
542			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
543			"layout(vertices=1) out;\n"
544			"layout(binding = 1, std430) buffer Buffer1\n"
545			"{\n"
546			"  uint result[];\n"
547			"} b1;\n"
548			"layout(binding = 4, std430) readonly buffer Buffer4\n"
549			"{\n"
550			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
551			"};\n"
552			"\n"
553			"void main (void)\n"
554			"{\n"
555			"  uvec4 mask = subgroupBallot(true);\n"
556			+ bdy +
557			"  b1.result[gl_PrimitiveID] = tempResult ? 1u : 0u;\n"
558			"  if (gl_InvocationID == 0)\n"
559			"  {\n"
560			"    gl_TessLevelOuter[0] = 1.0f;\n"
561			"    gl_TessLevelOuter[1] = 1.0f;\n"
562			"  }\n"
563			"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
564			"}\n";
565
566			programCollection.add("tesc") << glu::TessellationControlSource(tesc);
567		}
568
569		{
570			const string tese =
571				"${VERSION_DECL}\n"
572				"#extension GL_KHR_shader_subgroup_clustered: enable\n"
573				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
574				"layout(isolines) in;\n"
575				"layout(binding = 2, std430) buffer Buffer2\n"
576				"{\n"
577				"  uint result[];\n"
578				"} b2;\n"
579				"layout(binding = 4, std430) readonly buffer Buffer4\n"
580				"{\n"
581				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
582				"};\n"
583				"\n"
584				"void main (void)\n"
585				"{\n"
586				"  uvec4 mask = subgroupBallot(true);\n"
587				+ bdy +
588				"  b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = tempResult ? 1u : 0u;\n"
589				"  float pixelSize = 2.0f/1024.0f;\n"
590				"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
591				"}\n";
592			programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
593		}
594
595		{
596			const string geometry =
597				// version string added by addGeometryShadersFromTemplate
598				"#extension GL_KHR_shader_subgroup_clustered: enable\n"
599				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
600				"layout(${TOPOLOGY}) in;\n"
601				"layout(points, max_vertices = 1) out;\n"
602				"layout(binding = 3, std430) buffer Buffer3\n"
603				"{\n"
604				"  uint result[];\n"
605				"} b3;\n"
606				"layout(binding = 4, std430) readonly buffer Buffer4\n"
607				"{\n"
608				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
609				"};\n"
610				"\n"
611				"void main (void)\n"
612				"{\n"
613				"  uvec4 mask = subgroupBallot(true);\n"
614				+ bdy +
615				"  b3.result[gl_PrimitiveIDIn] = tempResult ? 1u : 0u;\n"
616				"  gl_Position = gl_in[0].gl_Position;\n"
617				"  EmitVertex();\n"
618				"  EndPrimitive();\n"
619				"}\n";
620			subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
621		}
622
623		{
624			const string fragment =
625				"${VERSION_DECL}\n"
626				"#extension GL_KHR_shader_subgroup_clustered: enable\n"
627				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
628				"precision highp int;\n"
629				"precision highp float;\n"
630				"layout(location = 0) out uint result;\n"
631				"layout(binding = 4, std430) readonly buffer Buffer4\n"
632				"{\n"
633				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
634				"};\n"
635				"void main (void)\n"
636				"{\n"
637				"  uvec4 mask = subgroupBallot(true);\n"
638				+ bdy +
639				"  result = tempResult ? 1u : 0u;\n"
640				"}\n";
641			programCollection.add("fragment") << glu::FragmentSource(fragment);
642		}
643
644		subgroups::addNoSubgroupShader(programCollection);
645	}
646}
647
648void supportedCheck (Context& context, CaseDefinition caseDef)
649{
650	if (!subgroups::isSubgroupSupported(context))
651		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
652
653	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_CLUSTERED_BIT))
654		TCU_THROW(NotSupportedError, "Device does not support subgroup clustered operations");
655
656	if (subgroups::isDoubleFormat(caseDef.format) &&
657			!subgroups::isDoubleSupportedForDevice(context))
658	{
659		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
660	}
661}
662
663tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
664{
665	if (!subgroups::areSubgroupOperationsSupportedForStage(
666				context, caseDef.shaderStage))
667	{
668		if (subgroups::areSubgroupOperationsRequiredForStage(
669					caseDef.shaderStage))
670		{
671			return tcu::TestStatus::fail(
672					   "Shader stage " +
673					   subgroups::getShaderStageName(caseDef.shaderStage) +
674					   " is required to support subgroup operations!");
675		}
676		else
677		{
678			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
679		}
680	}
681
682	subgroups::SSBOData inputData;
683	inputData.format = caseDef.format;
684	inputData.layout = subgroups::SSBOData::LayoutStd140;
685	inputData.numElements = subgroups::maxSupportedSubgroupSize();
686	inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
687	inputData.binding = 0u;
688
689	if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
690		return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
691	else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
692		return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
693	else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
694		return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
695	else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
696		return subgroups::makeTessellationEvaluationFrameBufferTest(context,  FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
697	else
698		TCU_THROW(InternalError, "Unhandled shader stage");
699}
700
701tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
702{
703	if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
704	{
705		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
706		{
707				return tcu::TestStatus::fail(
708						   "Shader stage " +
709						   subgroups::getShaderStageName(caseDef.shaderStage) +
710						   " is required to support subgroup operations!");
711		}
712		subgroups::SSBOData inputData;
713		inputData.format = caseDef.format;
714		inputData.layout = subgroups::SSBOData::LayoutStd430;
715		inputData.numElements = subgroups::maxSupportedSubgroupSize();
716		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
717		inputData.binding = 1u;
718
719		return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
720	}
721	else
722	{
723		int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
724
725		ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
726
727		if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
728		{
729			if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
730				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
731			else
732				stages = SHADER_STAGE_FRAGMENT_BIT;
733		}
734
735		if ((ShaderStageFlags)0u == stages)
736			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
737
738		subgroups::SSBOData inputData;
739		inputData.format			= caseDef.format;
740		inputData.layout			= subgroups::SSBOData::LayoutStd430;
741		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
742		inputData.initializeType	= subgroups::SSBOData::InitializeNonZero;
743		inputData.binding			= 4u;
744		inputData.stages			= stages;
745
746		return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
747	}
748}
749}
750
751deqp::TestCaseGroup* createSubgroupsClusteredTests(deqp::Context& testCtx)
752{
753	de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
754		testCtx, "graphics", "Subgroup clustered category tests: graphics"));
755	de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
756		testCtx, "compute", "Subgroup clustered category tests: compute"));
757	de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
758		testCtx, "framebuffer", "Subgroup clustered category tests: framebuffer"));
759
760	const ShaderStageFlags stages[] =
761	{
762		SHADER_STAGE_VERTEX_BIT,
763		SHADER_STAGE_TESS_EVALUATION_BIT,
764		SHADER_STAGE_TESS_CONTROL_BIT,
765		SHADER_STAGE_GEOMETRY_BIT
766	};
767
768	const Format formats[] =
769	{
770		FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT,
771		FORMAT_R32G32B32A32_SINT, FORMAT_R32_UINT, FORMAT_R32G32_UINT,
772		FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
773		FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT,
774		FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
775		FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT,
776		FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
777		FORMAT_R32_BOOL, FORMAT_R32G32_BOOL,
778		FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
779	};
780
781	for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
782	{
783		const Format format = formats[formatIndex];
784
785		for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex)
786		{
787			bool isBool = false;
788			bool isFloat = false;
789
790			switch (format)
791			{
792				default:
793					break;
794				case FORMAT_R32_SFLOAT:
795				case FORMAT_R32G32_SFLOAT:
796				case FORMAT_R32G32B32_SFLOAT:
797				case FORMAT_R32G32B32A32_SFLOAT:
798				case FORMAT_R64_SFLOAT:
799				case FORMAT_R64G64_SFLOAT:
800				case FORMAT_R64G64B64_SFLOAT:
801				case FORMAT_R64G64B64A64_SFLOAT:
802					isFloat = true;
803					break;
804				case FORMAT_R32_BOOL:
805				case FORMAT_R32G32_BOOL:
806				case FORMAT_R32G32B32_BOOL:
807				case FORMAT_R32G32B32A32_BOOL:
808					isBool = true;
809					break;
810			}
811
812			bool isBitwiseOp = false;
813
814			switch (opTypeIndex)
815			{
816				default:
817					break;
818				case OPTYPE_CLUSTERED_AND:
819				case OPTYPE_CLUSTERED_OR:
820				case OPTYPE_CLUSTERED_XOR:
821					isBitwiseOp = true;
822					break;
823			}
824
825			if (isFloat && isBitwiseOp)
826			{
827				// Skip float with bitwise category.
828				continue;
829			}
830
831			if (isBool && !isBitwiseOp)
832			{
833				// Skip bool when its not the bitwise category.
834				continue;
835			}
836
837			const std::string name = de::toLower(getOpTypeName(opTypeIndex))
838				+"_" + subgroups::getFormatNameForGLSL(format);
839
840			{
841				const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format};
842				SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
843			}
844
845			{
846				const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format};
847				SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(), name,
848										"", supportedCheck, initPrograms, test, caseDef);
849			}
850
851			for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
852			{
853				const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
854				SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(), name +"_" + getShaderStageName(caseDef.shaderStage), "",
855											supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
856			}
857		}
858	}
859	de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
860		testCtx, "clustered", "Subgroup clustered category tests"));
861
862	group->addChild(graphicGroup.release());
863	group->addChild(computeGroup.release());
864	group->addChild(framebufferGroup.release());
865
866	return group.release();
867}
868
869} // subgroups
870} // glc
871