1/*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2018-2019 NVIDIA Corporation
7 * Copyright (c) 2023 LunarG, Inc.
8 * Copyright (c) 2023 Nintendo
9 *
10 * Licensed under the Apache License, Version 2.0 (the "License");
11 * you may not use this file except in compliance with the License.
12 * You may obtain a copy of the License at
13 *
14 *	  http://www.apache.org/licenses/LICENSE-2.0
15 *
16 * Unless required by applicable law or agreed to in writing, software
17 * distributed under the License is distributed on an "AS IS" BASIS,
18 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 * See the License for the specific language governing permissions and
20 * limitations under the License.
21 *
22 *//*!
23 * \file
24 * \brief Vulkan Cooperative Matrix tests
25 *//*--------------------------------------------------------------------*/
26
27#include "vktComputeCooperativeMatrixTests.hpp"
28
29#include "vkBufferWithMemory.hpp"
30#include "vkImageWithMemory.hpp"
31#include "vkQueryUtil.hpp"
32#include "vkBuilderUtil.hpp"
33#include "vkCmdUtil.hpp"
34#include "vkTypeUtil.hpp"
35#include "vkObjUtil.hpp"
36
37#include "vktTestGroupUtil.hpp"
38#include "vktTestCase.hpp"
39
40#include "deDefs.h"
41#include "deFloat16.h"
42#include "deMath.h"
43#include "deRandom.h"
44#include "deSharedPtr.hpp"
45#include "deString.h"
46
47#include "tcuTestCase.hpp"
48#include "tcuTestLog.hpp"
49
50#include <string>
51#include <sstream>
52#include <set>
53#include <algorithm>
54
55namespace vkt
56{
57namespace compute
58{
59namespace
60{
61using namespace vk;
62using namespace std;
63
64//#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1
65
66DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
67DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
68DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
69DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR   == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV  );
70DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV );
71DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV );
72DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR  == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV );
73DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR   == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV  );
74DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV );
75DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV );
76DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR  == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV );
77
78DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR       == (uint32_t)VK_SCOPE_DEVICE_NV);
79DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR    == (uint32_t)VK_SCOPE_WORKGROUP_NV);
80DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR     == (uint32_t)VK_SCOPE_SUBGROUP_NV);
81DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);
82
83typedef enum
84{
85	UT_NV = 0,
86	UT_KHR_A,
87	UT_KHR_B,
88	UT_KHR_Result,
89} UseType;
90
91typedef enum
92{
93	TT_LENGTH = 0,
94	TT_CONSTANT,
95	TT_CONVERT,
96	TT_COMPOSITE,
97	TT_COMPOSITE_RVALUE,
98	TT_ADD,
99	TT_SUB,
100	TT_DIV,
101	TT_MUL,
102	TT_NEGATE,
103	TT_MATRIXTIMESSCALAR,
104	TT_FUNC,
105	TT_MATRIXMULADD,
106	TT_COMPOSITE_ARRAY,
107	TT_MATRIXMULADD_ARRAY,
108	TT_MATRIXMULADD_SATURATED,
109	TT_MATRIXMULADD_WRAPPING,
110	TT_MATRIXMULADD_STRIDE0,
111} TestType;
112
113typedef enum
114{
115	SC_BUFFER = 0,
116	SC_WORKGROUP,
117	SC_WORKGROUP_VARIABLE_POINTERS,
118	SC_BUFFER_VARIABLE_POINTERS,
119	SC_PHYSICAL_STORAGE_BUFFER,
120} StorageClass;
121
122enum SubgroupSizeMode
123{
124	SUBGROUP_SIZE_NONE = 0,
125	SUBGROUP_SIZE_MIN = 1,
126	SUBGROUP_SIZE_MAX = 2,
127};
128
129const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
130
131struct CaseDef
132{
133	TestType							testType;
134	deUint32							subgroupsPerWorkgroupX;
135	deUint32							subgroupsPerWorkgroupY;
136	deUint32							workgroupsX;
137	deUint32							workgroupsY;
138	VkComponentTypeKHR					inputType;
139	VkComponentTypeKHR					outputType;
140	bool								colMajor;
141	StorageClass						storageClass;
142	UseType								useType;
143	SubgroupSizeMode					subgroupSizeMode;
144	vk::ComputePipelineConstructionType	computePipelineConstructionType;
145};
146
147bool isKhr (UseType useType)
148{
149	return useType != UT_NV;
150}
151
152bool isMatrixMulAddOp (TestType testType)
153{
154	return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED || testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0;
155}
156
157template<typename T>
158VkResult getCooperativeMatrixProperties (const InstanceInterface&, VkPhysicalDevice, uint32_t*, T*)
159{
160	TCU_THROW(InternalError, "Not Implementetd");
161}
162
163VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesKHR* pProperties)
164{
165	return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
166}
167
168VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesNV* pProperties)
169{
170	return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
171}
172
173VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties (const VkCooperativeMatrixPropertiesNV& properties)
174{
175	VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();
176
177	result.sType					= (VkStructureType)		properties.sType;
178	result.pNext					= (void*)				properties.pNext;
179	result.MSize					= (uint32_t)			properties.MSize;
180	result.NSize					= (uint32_t)			properties.NSize;
181	result.KSize					= (uint32_t)			properties.KSize;
182	result.AType					= (VkComponentTypeKHR)	properties.AType;
183	result.BType					= (VkComponentTypeKHR)	properties.BType;
184	result.CType					= (VkComponentTypeKHR)	properties.CType;
185	result.ResultType				= (VkComponentTypeKHR)	properties.DType;
186	result.saturatingAccumulation	= (VkBool32)			VK_FALSE;
187	result.scope					= (VkScopeKHR)			properties.scope;
188
189	return result;
190}
191
192std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties (const std::vector <VkCooperativeMatrixPropertiesNV>& properties)
193{
194	std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());
195
196	for (size_t i = 0; i < properties.size(); ++i)
197		result[i] = convertCooperativeMatrixProperties(properties[i]);
198
199	return result;
200}
201
202template<typename T>
203void getCooperativeMatrixPropertiesAll (Context& context, std::vector<T>& properties)
204{
205	deUint32	propertyCount	= 0;
206
207	VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, (T*)DE_NULL));
208
209	if (propertyCount > 0)
210	{
211		const T sample = initVulkanStructureConst();
212
213		properties.resize(propertyCount, sample);
214
215		VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, properties.data()));
216	}
217	else
218	{
219		properties.clear();
220	}
221}
222
223std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted (Context& context, const bool khr)
224{
225	std::vector<VkCooperativeMatrixPropertiesKHR> properties;
226
227	if (khr)
228	{
229		getCooperativeMatrixPropertiesAll(context, properties);
230	}
231	else
232	{
233		std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;
234
235		getCooperativeMatrixPropertiesAll(context, propertiesNV);
236
237		properties = convertCooperativeMatrixProperties(propertiesNV);
238	}
239
240	return properties;
241}
242
243deUint32 getSubgroupSizeFromMode (Context&					context,
244								  const SubgroupSizeMode	subgroupSizeMode)
245{
246#ifndef CTS_USES_VULKANSC
247	const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties = context.getSubgroupSizeControlProperties();
248#else
249	const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
250#endif // CTS_USES_VULKANSC
251
252	switch (subgroupSizeMode)
253	{
254		case SUBGROUP_SIZE_MAX:		return subgroupSizeControlProperties.maxSubgroupSize;
255		case SUBGROUP_SIZE_MIN:		return subgroupSizeControlProperties.minSubgroupSize;
256		case SUBGROUP_SIZE_NONE:	return context.getSubgroupProperties().subgroupSize;
257		default:					TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
258	}
259}
260
261
262class CooperativeMatrixTestInstance : public TestInstance
263{
264public:
265						CooperativeMatrixTestInstance	(Context& context, const CaseDef& data);
266						~CooperativeMatrixTestInstance	(void);
267	tcu::TestStatus		iterate							(void);
268private:
269	CaseDef			m_data;
270};
271
272CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
273	: vkt::TestInstance		(context)
274	, m_data				(data)
275{
276}
277
278CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
279{
280}
281
282class CooperativeMatrixTestCase : public TestCase
283{
284	public:
285								CooperativeMatrixTestCase		(tcu::TestContext& context, const char* name, const CaseDef data);
286								~CooperativeMatrixTestCase	(void);
287	virtual	void				initPrograms		(SourceCollections& programCollection) const;
288	virtual TestInstance*		createInstance		(Context& context) const;
289	virtual void				checkSupport		(Context& context) const;
290
291private:
292	CaseDef					m_data;
293};
294
295CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const CaseDef data)
296	: vkt::TestCase	(context, name)
297	, m_data		(data)
298{
299}
300
301CooperativeMatrixTestCase::~CooperativeMatrixTestCase (void)
302{
303}
304
305void CooperativeMatrixTestCase::checkSupport (Context& context) const
306{
307	if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
308	{
309		TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
310	}
311
312	if (isKhr(m_data.useType))
313	{
314		if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
315		{
316			TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
317		}
318	}
319	else
320	{
321		if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
322		{
323			TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
324		}
325	}
326
327	if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
328	{
329		TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
330	}
331
332	if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
333		!context.getVariablePointersFeatures().variablePointers)
334	{
335		TCU_THROW(NotSupportedError, "variable pointers not supported");
336	}
337
338	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
339	{
340		TCU_THROW(NotSupportedError, "buffer device address not supported");
341	}
342
343	if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
344		(m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
345	{
346		TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
347	}
348
349	std::vector<VkCooperativeMatrixPropertiesKHR>	properties		= getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
350	bool											supported[2]	= { false, false };
351	const auto										isMMA			= isMatrixMulAddOp(m_data.testType);
352	const auto										isMMASat		= m_data.testType == TT_MATRIXMULADD_SATURATED;
353
354	for (size_t i = 0; i < properties.size(); ++i)
355	{
356		const VkCooperativeMatrixPropertiesKHR*	p	= &properties[i];
357
358		if (p->scope != VK_SCOPE_SUBGROUP_KHR)
359			continue;
360
361		if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
362			continue;
363
364		if (isMMA)
365		{
366			if (p->AType == m_data.inputType &&
367				p->BType == m_data.inputType &&
368				p->CType == m_data.outputType &&
369				p->ResultType == m_data.outputType)
370			{
371				supported[0] = supported[1] = true;
372			}
373		}
374		else
375		{
376			const VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };
377
378			for (deUint32 j = 0; j < 2; ++j)
379			{
380				switch (m_data.useType)
381				{
382					case UT_NV:
383					{
384						if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->ResultType == types[j])
385							supported[j] = true;
386
387						break;
388					}
389					case UT_KHR_A:
390					{
391						if (p->AType == types[j])
392							supported[j] = true;
393
394						break;
395					}
396					case UT_KHR_B:
397					{
398						if (p->BType == types[j])
399							supported[j] = true;
400
401						break;
402					}
403					case UT_KHR_Result:
404					{
405						if (p->ResultType == types[j])
406							supported[j] = true;
407
408						break;
409					}
410					default:
411						TCU_THROW(InternalError, "Unsupported use type");
412				}
413			}
414		}
415	}
416
417	if (!supported[0] || !supported[1])
418		TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
419
420	checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(), m_data.computePipelineConstructionType);
421}
422
423struct {
424	const char *typeName;
425	const char *coopmatTypeName;
426	deUint32 bits;
427	bool isSigned;
428} componentTypeInfo[] =
429{
430	{ "float16_t",	"fcoopmatNV",	16, true },
431	{ "float32_t",	"fcoopmatNV",	32, true },
432	{ "float64_t",	"fcoopmatNV",	64, true },
433	{ "int8_t",		"icoopmatNV",	8, true },
434	{ "int16_t",	"icoopmatNV",	16, true },
435	{ "int32_t",	"icoopmatNV",	32, true },
436	{ "int64_t",	"icoopmatNV",	64, true },
437	{ "uint8_t",	"ucoopmatNV",	8, false },
438	{ "uint16_t",	"ucoopmatNV",	16, false },
439	{ "uint32_t",	"ucoopmatNV",	32, false },
440	{ "uint64_t",	"ucoopmatNV",	64, false },
441};
442
443bool isFloatType (VkComponentTypeKHR t)
444{
445	switch (t)
446	{
447		case VK_COMPONENT_TYPE_FLOAT16_KHR:
448		case VK_COMPONENT_TYPE_FLOAT32_KHR:
449		case VK_COMPONENT_TYPE_FLOAT64_KHR:
450			return true;
451		default:
452			return false;
453	}
454}
455
456bool isSIntType (VkComponentTypeKHR t)
457{
458	switch (t)
459	{
460		case VK_COMPONENT_TYPE_SINT8_KHR:
461		case VK_COMPONENT_TYPE_SINT16_KHR:
462		case VK_COMPONENT_TYPE_SINT32_KHR:
463		case VK_COMPONENT_TYPE_SINT64_KHR:
464			return true;
465		default:
466			return false;
467	}
468}
469
470void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
471{
472	const char*			suffix	= (isKhr(m_data.useType) ? "" : "NV");
473	const char*			ext		= isKhr(m_data.useType)
474								? "#extension GL_KHR_cooperative_matrix : enable\n"
475								: "#extension GL_NV_cooperative_matrix : enable\n"
476								  "#extension GL_NV_integer_cooperative_matrix : enable\n";
477	const char*			sat		= (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
478	std::stringstream	css;
479	css << "#version 450 core\n";
480	css << "#pragma use_vulkan_memory_model\n";
481	css <<
482		"#extension GL_KHR_shader_subgroup_basic : enable\n"
483		"#extension GL_KHR_memory_scope_semantics : enable\n"
484		<< ext <<
485		"#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
486		"#extension GL_EXT_buffer_reference : enable\n"
487		"// strides overriden by spec constants\n"
488		"layout(constant_id = 2) const int AStride = 1;\n"
489		"layout(constant_id = 3) const int BStride = 1;\n"
490		"layout(constant_id = 4) const int CStride = 1;\n"
491		"layout(constant_id = 5) const int OStride = 1;\n"
492		"layout(constant_id = 6) const int M = 1;\n"
493		"layout(constant_id = 7) const int N = 1;\n"
494		"layout(constant_id = 8) const int K = 1;\n"
495		"layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";
496
497	if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
498		css << "#pragma use_variable_pointers\n";
499
500	struct
501	{
502		string rows, cols;
503	} dims[4];
504
505	if (isMatrixMulAddOp(m_data.testType))
506	{
507		dims[0].rows = "M";
508		dims[0].cols = "K";
509		dims[1].rows = "K";
510		dims[1].cols = "N";
511		dims[2].rows = "M";
512		dims[2].cols = "N";
513		dims[3].rows = "M";
514		dims[3].cols = "N";
515	}
516	else
517	{
518		dims[0].rows = "M";
519		dims[0].cols = "N";
520		dims[1].rows = "M";
521		dims[1].cols = "N";
522		dims[2].rows = "M";
523		dims[2].cols = "N";
524		dims[3].rows = "M";
525		dims[3].cols = "N";
526	}
527
528	const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
529	const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
530	const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
531	const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;
532
533	css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
534	css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";
535
536	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
537	{
538		css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
539		css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
540		css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
541		css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
542		css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
543	}
544	else
545	{
546		css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
547		css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
548		css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
549		css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
550	}
551
552	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
553	{
554		css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
555		css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
556		css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
557		css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
558	}
559
560	std::stringstream matAType, matBType, matCType, outputMatType;
561
562	if (isKhr(m_data.useType))
563	{
564		const bool	useSame		= !isMatrixMulAddOp(m_data.testType);
565		const char*	sameType	= m_data.useType == UT_KHR_A ? "gl_MatrixUseA"
566								: m_data.useType == UT_KHR_B ? "gl_MatrixUseB"
567								: m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator"
568								: "Invalid use";
569		const char*	atype		= useSame ? sameType : "gl_MatrixUseA";
570		const char*	btype		= useSame ? sameType : "gl_MatrixUseB";
571		const char*	ctype		= useSame ? sameType : "gl_MatrixUseAccumulator";
572		const char*	rtype		= useSame ? sameType : "gl_MatrixUseAccumulator";
573
574		matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ", " << atype << ">";
575		matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ", " << btype << ">";
576		matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
577		outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ", " << rtype << ">";
578	}
579	else
580	{
581		matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
582		matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
583		matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
584		outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";
585	}
586
587	css << matAType.str() << " matA;\n";
588	css << matBType.str() << " matB;\n";
589	css << matCType.str() << " matC;\n";
590	css << outputMatType.str() << " matO;\n";
591
592	if (m_data.testType == TT_CONSTANT)
593		css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";
594
595	if (m_data.testType == TT_FUNC)
596		css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";
597
598	css <<
599		"void main()\n"
600		"{\n"
601		// matrixID is the x,y index of the matrix owned by this subgroup.
602		"   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
603		"   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";
604
605	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
606	{
607		css << "   InputA inputA = params.inputA;\n";
608		css << "   InputB inputB = params.inputB;\n";
609		css << "   InputC inputC = params.inputC;\n";
610		css << "   Output outputO = params.outputO;\n";
611	}
612
613	string strides[4];
614	for (deUint32 i = 0; i < 4; ++i)
615	{
616		strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
617	}
618
619	// element<i> is the starting element in buffer memory.
620	// elementS<i> is the starting element in shared memory.
621	css << "   uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
622		   "   uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
623		   "   uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
624		   "   uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
625		   "   uint elementS0, elementS1, elementS2, elementS3;\n";
626
627	// For shared memory tests, copy the matrix from buffer memory into
628	// workgroup memory. For simplicity, do it all on a single thread.
629	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
630	{
631		const char *name[] =
632		{
633			"sharedA",
634			"sharedB",
635			"sharedC",
636		};
637		const char *inputName[] =
638		{
639			"inputA",
640			"inputB",
641			"inputC",
642		};
643		for (deUint32 m = 0; m < 4; ++m)
644		{
645			string sharedStride = strides[m] + " / workgroupsX";
646			css << "       elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
647		}
648		css << "   if (subgroupElect()) {\n";
649		// copy all three input buffers.
650		for (deUint32 m = 0; m < 3; ++m)
651		{
652			string sharedStride = strides[m] + " / workgroupsX";
653			css <<  "       for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
654					"       for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
655					"           int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
656					"           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
657					"           " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
658					"       }\n"
659					"       }\n";
660			strides[m] = sharedStride;
661		}
662		css << "   }\n";
663		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
664	}
665
666	const char *colMajorNV  = (m_data.colMajor ? "true" : "false");
667	const char* colMajorKHR = (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
668	const char* colMajor    = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);
669
670	string loadStrides[3] = { strides[0], strides[1], strides[2] };
671	// Load with a stride of 0
672	if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
673		loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";
674
675	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
676	{
677		css <<  "   coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor << ");\n"
678				"   coopMatLoad" << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor << ");\n"
679				"   coopMatLoad" << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
680	}
681	else
682	{
683		css << "   coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor << ");\n"
684			   "   coopMatLoad" << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor << ");\n"
685			   "   coopMatLoad" << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
686	}
687
688	if (m_data.testType == TT_COMPOSITE_ARRAY ||
689		m_data.testType == TT_MATRIXMULADD_ARRAY)
690	{
691		css << "   " << matAType.str() << " matAArr[2];\n    matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
692			   "   " << matBType.str() << " matBArr[2];\n    matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
693			   "   " << matCType.str() << " matCArr[2];\n    matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
694			   "   " << outputMatType.str() << " matOArr[2];\n";
695	}
696
697	switch (m_data.testType)
698	{
699	default:
700		DE_ASSERT(0);
701		// fall through
702	case TT_LENGTH:
703		css << "   matO = " << outputMatType.str() << "(matO.length());\n";
704		break;
705	case TT_CONSTANT:
706		css << "   matO = matConst;\n";
707		break;
708	case TT_CONVERT:
709		css << "   matO = " << outputMatType.str() << "(matA);\n";
710		break;
711	case TT_COMPOSITE:
712		css << "   " << matAType.str() << " t = " << matAType.str() << "(matB[0]);\n"
713			"   for (int i = 1; i < matA.length(); ++i) {\n"
714			"       matO[i] = matA[i] + matB[i];\n"
715			"   }\n"
716			"   if (matA.length() > 0)\n"
717			"       matO[0] = matA[0] + t[0];\n";
718		break;
719	case TT_COMPOSITE_RVALUE:
720		css << "   for (int i = 1; i < matA.length(); ++i) {\n"
721			   "       matO[i] = matA[i] + matB[i];\n"
722			   "   }\n"
723			   "   " << matAType.str() << " t = matA;\n"
724			   "   if (matA.length() > 0) {\n"
725			   "       matO[0] = (t += matB)[0];\n"
726			   "   }\n";
727		break;
728	case TT_COMPOSITE_ARRAY:
729		css << "   for (int i = 0; i < matA.length(); ++i) {\n"
730			   "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
731			   "   }\n";
732		break;
733	case TT_ADD:
734		css << "   matO = matA + matB;\n";
735		break;
736	case TT_SUB:
737		css << "   matO = matA - matB;\n";
738		break;
739	case TT_DIV:
740		css << "   matO = matA / matB;\n";
741		break;
742	case TT_MUL:
743		css << "   matO = matA * matB;\n";
744		break;
745	case TT_NEGATE:
746		css << "   matO = -matA;\n";
747		break;
748	case TT_FUNC:
749		css << "   matO = f(matA);\n";
750		break;
751	case TT_MATRIXTIMESSCALAR:
752		css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
753		break;
754	case TT_MATRIXMULADD_STRIDE0:
755	case TT_MATRIXMULADD_WRAPPING:
756	case TT_MATRIXMULADD_SATURATED:
757	case TT_MATRIXMULADD:
758		css << "   matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
759		break;
760	case TT_MATRIXMULADD_ARRAY:
761		css << "   matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
762		break;
763	}
764
765	if (m_data.testType == TT_COMPOSITE_ARRAY ||
766		m_data.testType == TT_MATRIXMULADD_ARRAY)
767	{
768		css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
769		css << "   matO = matOArr[1];\n";
770	}
771
772	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
773	{
774		string sharedStride = strides[3] + " / workgroupsX";
775		css << "   coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
776		css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
777		css << "   if (subgroupElect()) {\n";
778		css << "       for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
779			   "       for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
780			   "           int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
781			   "           int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
782			   "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
783			   "       }\n"
784			   "       }\n";
785		css << "   }\n";
786	}
787	else
788	{
789		css << "   coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
790	}
791
792	css <<
793		"}\n";
794
795	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
796
797	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
798}
799
800TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
801{
802	return new CooperativeMatrixTestInstance(context, m_data);
803}
804
805void setDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i, float value)
806{
807	if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
808	{
809		((float *)base)[i] = value;
810	}
811	else
812	{
813		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
814		((deFloat16 *)base)[i] = deFloat32To16(value);
815	}
816}
817
818float getDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i)
819{
820	if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
821	{
822		return ((float *)base)[i];
823	}
824	else
825	{
826		DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
827		return deFloat16To32(((deFloat16 *)base)[i]);
828	}
829}
830
831void setDataInt (void *base, VkComponentTypeKHR dt, deUint32 i, deUint32 value)
832{
833	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
834
835	switch (dt)
836	{
837		case VK_COMPONENT_TYPE_UINT8_KHR:	((deUint8  *)base)[i] = (deUint8)value; break;
838		case VK_COMPONENT_TYPE_UINT16_KHR:	((deUint16 *)base)[i] = (deUint16)value; break;
839		case VK_COMPONENT_TYPE_UINT32_KHR:	((deUint32 *)base)[i] = (deUint32)value; break;
840		case VK_COMPONENT_TYPE_SINT8_KHR:	((deInt8  *)base)[i] = (deInt8)value; break;
841		case VK_COMPONENT_TYPE_SINT16_KHR:	((deInt16 *)base)[i] = (deInt16)value; break;
842		case VK_COMPONENT_TYPE_SINT32_KHR:	((deInt32 *)base)[i] = (deInt32)value; break;
843		default:							TCU_THROW(InternalError, "Unsupported type");
844	}
845}
846
847deUint32 getDataInt (void *base, VkComponentTypeKHR dt, deUint32 i)
848{
849	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
850
851	switch (dt)
852	{
853		case VK_COMPONENT_TYPE_UINT8_KHR:	return ((deUint8*)base)[i];
854		case VK_COMPONENT_TYPE_UINT16_KHR:	return ((deUint16*)base)[i];
855		case VK_COMPONENT_TYPE_UINT32_KHR:	return ((deUint32*)base)[i];
856		case VK_COMPONENT_TYPE_SINT8_KHR:	return ((deInt8*)base)[i];
857		case VK_COMPONENT_TYPE_SINT16_KHR:	return ((deInt16*)base)[i];
858		case VK_COMPONENT_TYPE_SINT32_KHR:	return ((deInt32 *)base)[i];
859		default:							TCU_THROW(InternalError, "Unsupported type");
860	}
861}
862
863template <typename T>
864T getDataConvertedToT (void *base, VkComponentTypeKHR dt, deUint32 i)
865{
866	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
867
868	switch (dt)
869	{
870		case VK_COMPONENT_TYPE_UINT8_KHR:	return (T)((deUint8*)base)[i];
871		case VK_COMPONENT_TYPE_UINT16_KHR:	return (T)((deUint16*)base)[i];
872		case VK_COMPONENT_TYPE_UINT32_KHR:	return (T)((deUint32*)base)[i];
873		case VK_COMPONENT_TYPE_SINT8_KHR:	return (T)((deInt8*)base)[i];
874		case VK_COMPONENT_TYPE_SINT16_KHR:	return (T)((deInt16*)base)[i];
875		case VK_COMPONENT_TYPE_SINT32_KHR:	return (T)((deInt32 *)base)[i];
876		case VK_COMPONENT_TYPE_FLOAT32_KHR:
877		{
878			float temp = ((float *)base)[i];
879			if (std::numeric_limits<T>::min() == 0)
880				temp = std::max(temp, 0.0f);
881			return (T)temp;
882		}
883		case VK_COMPONENT_TYPE_FLOAT16_KHR:
884		{
885			float temp = deFloat16To32(((deFloat16 *)base)[i]);
886			if (std::numeric_limits<T>::min() == 0)
887				temp = std::max(temp, 0.0f);
888			return (T)temp;
889		}
890		default:
891			TCU_THROW(InternalError, "Unsupported type");
892	}
893}
894
895template<typename T>
896T satAdd(T a, T b)
897{
898	if (a > 0)
899	{
900		if (b > std::numeric_limits<T>::max() - a)
901			return std::numeric_limits<T>::max();
902	}
903	else if (b < std::numeric_limits<T>::min() - a)
904	{
905		return std::numeric_limits<T>::min();
906	}
907
908	return (T)(a + b);
909}
910
911deUint32 satAddData (VkComponentTypeKHR dt, deUint32 a, deUint32 b)
912{
913	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
914
915	switch (dt)
916	{
917		case VK_COMPONENT_TYPE_UINT8_KHR:	return deMinu32(a + b, std::numeric_limits<deUint8>::max());
918		case VK_COMPONENT_TYPE_UINT16_KHR:	return deMinu32(a + b, std::numeric_limits<deUint16>::max());
919		case VK_COMPONENT_TYPE_UINT32_KHR:	return (a + b >= a) ? a + b : std::numeric_limits<deUint32>::max();
920		case VK_COMPONENT_TYPE_SINT8_KHR:	return (deUint32)satAdd((deInt8)a,  (deInt8)b);
921		case VK_COMPONENT_TYPE_SINT16_KHR:	return (deUint32)satAdd((deInt16)a, (deInt16)b);
922		case VK_COMPONENT_TYPE_SINT32_KHR:	return (deUint32)satAdd((deInt32)a, (deInt32)b);
923		default:							TCU_THROW(InternalError, "Unsupported type");
924	}
925}
926
927deUint32 getLimit (VkComponentTypeKHR dt, bool positive)
928{
929	DE_ASSERT(componentTypeInfo[dt].bits <= 32);
930
931	switch (dt)
932	{
933		case VK_COMPONENT_TYPE_UINT8_KHR:	return deUint32(positive ? std::numeric_limits<deUint8>::max()  : std::numeric_limits<deUint8>::min());
934		case VK_COMPONENT_TYPE_UINT16_KHR:	return deUint32(positive ? std::numeric_limits<deUint16>::max() : std::numeric_limits<deUint16>::min());
935		case VK_COMPONENT_TYPE_UINT32_KHR:	return deUint32(positive ? std::numeric_limits<deUint32>::max() : std::numeric_limits<deUint32>::min());
936		case VK_COMPONENT_TYPE_SINT8_KHR:	return deUint32(positive ? std::numeric_limits<deInt8>::max()   : std::numeric_limits<deInt8>::min());
937		case VK_COMPONENT_TYPE_SINT16_KHR:	return deUint32(positive ? std::numeric_limits<deInt16>::max()  : std::numeric_limits<deInt16>::min());
938		case VK_COMPONENT_TYPE_SINT32_KHR:	return deUint32(positive ? std::numeric_limits<deInt32>::max()  : std::numeric_limits<deInt32>::min());
939		default:							TCU_THROW(InternalError, "Unsupported type");
940	}
941}
942
943void setSingleElementInt (void *data, VkComponentTypeKHR dt, deUint32 start, deUint32 count, deUint32 step, deUint32 at, deUint32 val)
944{
945	for (deUint32 i = 0; i < count; i++)
946		setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
947}
948
949#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
950string dumpWholeMatrix (void* data, VkComponentTypeKHR dt, bool colMajor, deUint32 matrixElemCount, deUint32 stride)
951{
952	const deUint32		rowsCount	= colMajor ? stride : matrixElemCount / stride;
953	const deUint32		colsCount	= colMajor ? matrixElemCount / stride : stride;
954	bool				floatType	= isFloatType(dt);
955	bool				sIntType	= isSIntType(dt);
956	std::stringstream	ss;
957
958	DE_ASSERT(rowsCount * colsCount == matrixElemCount);
959
960	for (deUint32 r = 0; r < rowsCount; r++)
961	{
962		for (deUint32 c = 0; c < colsCount; c++)
963		{
964			const deUint32 i = colMajor ? rowsCount * c + r : colsCount * r + c;
965
966			if (floatType)
967				ss << getDataFloat(data, dt, i) << "\t";
968			else if (sIntType)
969				ss << (deInt32)getDataInt(data, dt, i) << "\t";
970			else
971				ss << getDataInt(data, dt, i) << "\t";
972		}
973
974		ss << std::endl;
975	}
976
977	return ss.str();
978}
979#endif
980
981tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
982{
983	const DeviceInterface&	vk						= m_context.getDeviceInterface();
984	const VkDevice			device					= m_context.getDevice();
985	Allocator&				allocator				= m_context.getDefaultAllocator();
986	MemoryRequirement		memoryDeviceAddress		= m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
987													  m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
988	qpTestResult			finalres				= QP_TEST_RESULT_NOT_SUPPORTED;
989	tcu::TestLog&			log						= m_context.getTestContext().getLog();
990	const bool				saturated				= (m_data.testType == TT_MATRIXMULADD_SATURATED);
991	const deUint32			subgroupSize			= getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
992	const float				epsilon					= 1.0f / float(1ull<<17); // 131072 is epsilon circa 1e-5
993
994	deRandom rnd;
995	deRandom_init(&rnd, 1234);
996
997	std::vector<VkCooperativeMatrixPropertiesKHR>	properties = getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));
998
999	struct TestTuple
1000	{
1001		TestTuple() {}
1002		TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}
1003
1004		bool operator<(const TestTuple &other) const
1005		{
1006			return M < other.M ||
1007				   (M == other.M && N < other.N) ||
1008				   (M == other.M && N == other.N && K < other.K);
1009		}
1010
1011		deUint32 M, N, K;
1012	};
1013
1014	vector<TestTuple> testSizes;
1015
1016	if (isMatrixMulAddOp(m_data.testType))
1017	{
1018		for (size_t i = 0; i < properties.size(); ++i)
1019		{
1020			VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1021
1022			if (p->AType == m_data.inputType &&
1023				p->BType == m_data.inputType &&
1024				p->CType == m_data.outputType &&
1025				p->ResultType == m_data.outputType &&
1026				p->scope == VK_SCOPE_SUBGROUP_KHR)
1027			{
1028				testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
1029			}
1030		}
1031	}
1032	else
1033	{
1034		set<TestTuple> typeSizes[2];
1035		VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };
1036		const bool aType = (m_data.useType == UT_KHR_A) || (m_data.useType == UT_NV);
1037		const bool bType = (m_data.useType == UT_KHR_B) || (m_data.useType == UT_NV);
1038		const bool rType = (m_data.useType == UT_KHR_Result) || (m_data.useType == UT_NV);
1039
1040		for (deUint32 i = 0; i < properties.size(); ++i)
1041		{
1042			VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1043
1044			if (p->scope != VK_SCOPE_SUBGROUP_KHR)
1045				continue;
1046
1047			for (deUint32 j = 0; j < 2; ++j)
1048			{
1049				// For these tests, m_data.M/N are always the matrix size. Check if they match
1050				// any input or output in the list.
1051				if (aType && p->AType == types[j]) typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
1052				if (bType && p->BType == types[j]) typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
1053				if (rType && (p->CType == types[j] || p->ResultType == types[j])) typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
1054			}
1055		}
1056		// Test those sizes that are supported for both the input and output type.
1057		std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
1058							  typeSizes[1].begin(), typeSizes[1].end(),
1059							  std::back_inserter(testSizes));
1060	}
1061
1062	properties.resize(0);
1063
1064	for (unsigned int s = 0; s < testSizes.size(); ++s)
1065	{
1066		// When testing a multiply, MxNxK is the type of matrix multiply.
1067		// Otherwise, MxN is the size of the input/output matrices
1068		deUint32 M, N, K;
1069		M = testSizes[s].M;
1070		N = testSizes[s].N;
1071		K = testSizes[s].K;
1072
1073		log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
1074
1075		struct
1076		{
1077			deUint32 rows, cols;
1078		} dims[4];
1079
1080		if (isMatrixMulAddOp(m_data.testType))
1081		{
1082			dims[0].rows = M;
1083			dims[0].cols = K;
1084			dims[1].rows = K;
1085			dims[1].cols = N;
1086			dims[2].rows = M;
1087			dims[2].cols = N;
1088			dims[3].rows = M;
1089			dims[3].cols = N;
1090		}
1091		else
1092		{
1093			dims[0].rows = M;
1094			dims[0].cols = N;
1095			dims[1].rows = M;
1096			dims[1].cols = N;
1097			dims[2].rows = M;
1098			dims[2].cols = N;
1099			dims[3].rows = M;
1100			dims[3].cols = N;
1101		}
1102
1103		VkComponentTypeKHR dataTypes[4];
1104		size_t elementSize[4];
1105		VkDeviceSize bufferSizes[5];
1106		de::MovePtr<BufferWithMemory> buffers[5];
1107		vk::VkDescriptorBufferInfo bufferDescriptors[5];
1108		deUint32 strides[4]; // in elements
1109		deUint32 loadStrides[4];
1110		deUint32 totalElements[4];
1111
1112		for (deUint32 i = 0; i < 5; ++i)
1113		{
1114			if (i < 4)
1115			{
1116				// A/B use input type, C/D use output type
1117				dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
1118				elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;
1119
1120				strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
1121				loadStrides[i] = strides[i];
1122				totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;
1123
1124				bufferSizes[i] = totalElements[i] * elementSize[i];
1125			}
1126			else
1127			{
1128				bufferSizes[4] = sizeof(VkDeviceAddress)*4;
1129			}
1130
1131			try
1132			{
1133				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1134					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
1135					(memoryDeviceAddress == MemoryRequirement::DeviceAddress ?  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
1136					MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
1137			}
1138			catch (const tcu::NotSupportedError&)
1139			{
1140				buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1141					vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
1142					(memoryDeviceAddress == MemoryRequirement::DeviceAddress ?  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
1143					MemoryRequirement::HostVisible | memoryDeviceAddress));
1144			}
1145
1146			bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
1147		}
1148
1149		// Load with a stride of 0
1150		if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1151			loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;
1152
1153		void *ptrs[5];
1154		for (deUint32 i = 0; i < 5; ++i)
1155		{
1156			ptrs[i] = buffers[i]->getAllocation().getHostPtr();
1157		}
1158
1159		vk::DescriptorSetLayoutBuilder layoutBuilder;
1160
1161		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1162		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1163		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1164		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1165		layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1166
1167		vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));
1168
1169		vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
1170			.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
1171			.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1172		vk::Unique<vk::VkDescriptorSet>			descriptorSet		(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1173
1174		vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1175		if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
1176		{
1177			VkBufferDeviceAddressInfo info
1178			{
1179				VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,		// VkStructureType	 sType;
1180				DE_NULL,											// const void*		 pNext;
1181				0,													// VkBuffer			buffer
1182			};
1183			VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
1184			for (deUint32 i = 0; i < 4; ++i)
1185			{
1186				info.buffer = **buffers[i];
1187				VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
1188				addrsInMemory[i] = addr;
1189			}
1190			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
1191				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
1192		}
1193		else
1194		{
1195			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1196				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1197			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1198				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1199			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1200				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1201			setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
1202				VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
1203		}
1204
1205		setUpdateBuilder.update(vk, device);
1206
1207		VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1208
1209		const deUint32 specData[9] =
1210		{
1211			subgroupSize * m_data.subgroupsPerWorkgroupX,
1212			m_data.subgroupsPerWorkgroupY,
1213			strides[0],
1214			strides[1],
1215			strides[2],
1216			strides[3],
1217			M,
1218			N,
1219			K,
1220		};
1221
1222		const vk::VkSpecializationMapEntry entries[9] =
1223		{
1224			{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1225			{1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
1226			{2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
1227			{3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
1228			{4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
1229			{5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
1230			{6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
1231			{7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
1232			{8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
1233		};
1234
1235		const vk::VkSpecializationInfo specInfo =
1236		{
1237			9,						// mapEntryCount
1238			entries,				// pMapEntries
1239			sizeof(specData),		// dataSize
1240			specData				// pData
1241		};
1242
1243		for (deUint32 i = 0; i < 4; ++i)
1244			for (deUint32 j = 0; j < totalElements[i]; ++j)
1245			{
1246				if (isFloatType(dataTypes[i]))
1247				{
1248					if (!isMatrixMulAddOp(m_data.testType))
1249						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
1250					else
1251						setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
1252				}
1253				else
1254				{
1255					if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
1256					{
1257						// Choose matrix values that should cause overflow and underflow, to
1258						// verify wrapping behavior. Use the full range of values for A and B.
1259						// For matrix C, use values clustered near where the type wraps (zero
1260						// for unsigned, 2^(N-1) for signed).
1261						deUint32 bits = componentTypeInfo[dataTypes[i]].bits;
1262						deUint32 value;
1263						if (i == 2) {
1264							value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1265							if (componentTypeInfo[dataTypes[i]].isSigned)
1266								value += (1U << (bits - 1));
1267						} else {
1268							deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1269							value = deRandom_getUint32(&rnd) & mask;
1270						}
1271						setDataInt(ptrs[i], dataTypes[i], j, value);
1272					}
1273					else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1274					{
1275						setDataInt(ptrs[i], dataTypes[i], j, 0);
1276					}
1277					else
1278					{
1279						deUint32 value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1280						setDataInt(ptrs[i], dataTypes[i], j, value);
1281					}
1282				}
1283			}
1284
1285		if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1286		{
1287			// Set 1st row of A to 1,0,0...
1288			setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);
1289
1290			// Set 1st column of B to 1,0,0...
1291			setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);
1292
1293			// Set C element at {0,0} to maximum type value, thus we will have overflow at plus operation in D=A*B+C for this element
1294			setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));
1295
1296			// Check underflow if all involved elements support negative values
1297			if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
1298			{
1299				// Set 2nd row of A to 0,1,0,0...
1300				setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols, (m_data.colMajor ? strides[0] : 1), 1, 1);
1301
1302				// Set 2nd column of B to 0,-1,0,0...
1303				setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 1, -1);
1304
1305				// Set C element at {1,1} to minimum type value, thus we will have underflow at plus operation in D=A*B+C for this element
1306				setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
1307			}
1308		}
1309
1310		flushAlloc(vk, device, buffers[0]->getAllocation());
1311		flushAlloc(vk, device, buffers[1]->getAllocation());
1312		flushAlloc(vk, device, buffers[2]->getAllocation());
1313		flushAlloc(vk, device, buffers[3]->getAllocation());
1314
1315		ComputePipelineWrapper			pipeline(vk, device, m_data.computePipelineConstructionType, m_context.getBinaryCollection().get("test"));
1316		pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
1317		pipeline.setSpecializationInfo(specInfo);
1318		pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ? 0 : getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
1319		pipeline.buildPipeline();
1320
1321		const VkQueue					queue					= m_context.getUniversalQueue();
1322		Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
1323		Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1324
1325		beginCommandBuffer(vk, *cmdBuffer, 0u);
1326
1327		vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u, DE_NULL);
1328		pipeline.bind(*cmdBuffer);
1329
1330		vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);
1331
1332		endCommandBuffer(vk, *cmdBuffer);
1333
1334		submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1335
1336		invalidateAlloc(vk, device, buffers[3]->getAllocation());
1337
1338		qpTestResult res = QP_TEST_RESULT_PASS;
1339
1340		if (m_data.testType == TT_CONVERT)
1341		{
1342			for (deUint32 i = 0; i < totalElements[3]; ++i)
1343			{
1344				// Store results as double, which has enough range to hold all the other types exactly.
1345				double inputA, output;
1346
1347				// This loads the data according to dataTypes[0], and then converts to the template parameter type
1348				switch (dataTypes[3]) {
1349				case VK_COMPONENT_TYPE_UINT8_KHR:	inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i); break;
1350				case VK_COMPONENT_TYPE_UINT16_KHR:	inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i); break;
1351				case VK_COMPONENT_TYPE_UINT32_KHR:	inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i); break;
1352				case VK_COMPONENT_TYPE_SINT8_KHR:	inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i); break;
1353				case VK_COMPONENT_TYPE_SINT16_KHR:	inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i); break;
1354				case VK_COMPONENT_TYPE_SINT32_KHR:	inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i); break;
1355				case VK_COMPONENT_TYPE_FLOAT32_KHR: inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i); break;
1356				case VK_COMPONENT_TYPE_FLOAT16_KHR:
1357				{
1358					float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1359					inputA = deFloat16To32(deFloat32To16(temp));
1360					break;
1361				}
1362				default: TCU_THROW(InternalError, "Unexpected type");
1363				}
1364
1365				switch (dataTypes[3]) {
1366				case VK_COMPONENT_TYPE_UINT8_KHR:	output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i); break;
1367				case VK_COMPONENT_TYPE_UINT16_KHR:	output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i); break;
1368				case VK_COMPONENT_TYPE_UINT32_KHR:	output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i); break;
1369				case VK_COMPONENT_TYPE_SINT8_KHR:	output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i); break;
1370				case VK_COMPONENT_TYPE_SINT16_KHR:	output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i); break;
1371				case VK_COMPONENT_TYPE_SINT32_KHR:	output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i); break;
1372				case VK_COMPONENT_TYPE_FLOAT32_KHR: output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i); break;
1373				case VK_COMPONENT_TYPE_FLOAT16_KHR:
1374				{
1375					float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
1376					output = deFloat16To32(deFloat32To16(temp));
1377					break;
1378				}
1379				default: TCU_THROW(InternalError, "Unexpected type");
1380				}
1381
1382				if (inputA != output) {
1383					res = QP_TEST_RESULT_FAIL;
1384					break;
1385				}
1386			}
1387		}
1388		else if (isFloatType(dataTypes[0]))
1389		{
1390			if (!isMatrixMulAddOp(m_data.testType))
1391			{
1392				for (deUint32 i = 0; i < totalElements[3]; ++i)
1393				{
1394					float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
1395					float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
1396					float output = getDataFloat(ptrs[3], dataTypes[3], i);
1397					switch (m_data.testType)
1398					{
1399					case TT_LENGTH:
1400						if (output < 1.0f || output > (float)(N*M))
1401							res = QP_TEST_RESULT_FAIL;
1402						// We expect the matrix to be spread evenly across invocations, it is
1403						// surprising (but not necessarily illegal) if not
1404						if (output != (float)(N*M/subgroupSize) &&
1405							res == QP_TEST_RESULT_PASS)
1406							res = QP_TEST_RESULT_QUALITY_WARNING;
1407						break;
1408					case TT_CONSTANT:
1409						if (output != 1.0f)
1410							res = QP_TEST_RESULT_FAIL;
1411						break;
1412					case TT_COMPOSITE:
1413					case TT_COMPOSITE_RVALUE:
1414					case TT_COMPOSITE_ARRAY:
1415					case TT_ADD:
1416						if (output != inputA + inputB)
1417							res = QP_TEST_RESULT_FAIL;
1418						break;
1419					case TT_SUB:
1420						if (output != inputA - inputB)
1421							res = QP_TEST_RESULT_FAIL;
1422						break;
1423					case TT_DIV:
1424						{
1425							float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
1426							// division allows 2.5ulp, but we'll use 3.
1427							ulp *= 3;
1428							if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
1429								res = QP_TEST_RESULT_FAIL;
1430						}
1431						break;
1432					case TT_MUL:
1433					{
1434						if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
1435						{
1436							const float		expected32	= inputA * inputB;
1437							const deFloat16	expected16	= deFloat32To16(expected32);
1438							const float		expected	= deFloat16To32(expected16);
1439
1440							if (output != expected)
1441								res = QP_TEST_RESULT_FAIL;
1442						}
1443						else
1444						{
1445							if (output != inputA * inputB)
1446								res = QP_TEST_RESULT_FAIL;
1447						}
1448						break;
1449					}
1450					case TT_NEGATE:
1451					case TT_FUNC:
1452						if (output != -inputA)
1453							res = QP_TEST_RESULT_FAIL;
1454						break;
1455					case TT_MATRIXTIMESSCALAR:
1456						if (output != 6.0*inputA)
1457							res = QP_TEST_RESULT_FAIL;
1458						break;
1459					default:
1460						break;
1461					}
1462				}
1463			}
1464			else
1465			{
1466				deUint32 ik, kj, ij;
1467				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1468				{
1469					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1470					{
1471						for (deUint32 i = 0; i < M; ++i)
1472						{
1473							for (deUint32 j = 0; j < N; ++j)
1474							{
1475								float ref = 0;
1476								for (deUint32 k = 0; k < K; ++k)
1477								{
1478									if (m_data.colMajor)
1479										ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1480									else
1481										ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1482
1483									float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);
1484
1485									if (m_data.colMajor)
1486										kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1487									else
1488										kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1489
1490									float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);
1491
1492									ref += Aik*Bkj;
1493								}
1494
1495								if (m_data.colMajor)
1496									ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1497								else
1498									ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1499
1500								float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);
1501
1502								ref += Cij;
1503
1504								// When loading with stride 0, ij for matrix D is different from matrix C
1505								if (m_data.colMajor)
1506									ij = mX * M + i + strides[2] * (mY * N + j);
1507								else
1508									ij = mX * N + j + strides[2] * (mY * M + i);
1509
1510								float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);
1511
1512								if (fabs(ref - Dij) > epsilon)
1513								{
1514									res = QP_TEST_RESULT_FAIL;
1515								}
1516							}
1517						}
1518					}
1519				}
1520			}
1521		} else {
1522			if (!isMatrixMulAddOp(m_data.testType))
1523			{
1524				for (deUint32 i = 0; i < totalElements[3]; ++i)
1525				{
1526					deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
1527					deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
1528					deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
1529					int resultSize = componentTypeInfo[dataTypes[3]].bits;
1530					deUint32 mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
1531					switch (m_data.testType)
1532					{
1533					case TT_LENGTH:
1534						if (output < 1 || output > N*M)
1535							res = QP_TEST_RESULT_FAIL;
1536						// We expect the matrix to be spread evenly across invocations, it is
1537						// surprising (but not necessarily illegal) if not
1538						if (output != N*M/subgroupSize &&
1539							res == QP_TEST_RESULT_PASS)
1540							res = QP_TEST_RESULT_QUALITY_WARNING;
1541						break;
1542					case TT_CONSTANT:
1543						if (output != 1)
1544							res = QP_TEST_RESULT_FAIL;
1545						break;
1546					case TT_COMPOSITE:
1547					case TT_COMPOSITE_RVALUE:
1548					case TT_COMPOSITE_ARRAY:
1549					case TT_ADD:
1550						if ((output & mask) != ((inputA + inputB) & mask)) {
1551							res = QP_TEST_RESULT_FAIL;
1552						}
1553						break;
1554					case TT_SUB:
1555						if ((output & mask) != ((inputA - inputB) & mask))
1556							res = QP_TEST_RESULT_FAIL;
1557						break;
1558					case TT_DIV:
1559						{
1560							if (isSIntType(dataTypes[3]))
1561							{
1562								if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
1563									res = QP_TEST_RESULT_FAIL;
1564							} else
1565							{
1566								if (inputB != 0 && output != inputA / inputB)
1567									res = QP_TEST_RESULT_FAIL;
1568							}
1569						}
1570						break;
1571					case TT_MUL:
1572					{
1573						if (((deInt32)output & mask) != (((deInt32)inputA * (deInt32)inputB) & mask))
1574						{
1575							res = QP_TEST_RESULT_FAIL;
1576						}
1577
1578						break;
1579					}
1580					case TT_NEGATE:
1581					case TT_FUNC:
1582						if ((output & mask) != ((-(deInt32)inputA) & mask))
1583							res = QP_TEST_RESULT_FAIL;
1584						break;
1585					case TT_MATRIXTIMESSCALAR:
1586						if ((output & mask) != ((6*inputA) & mask)) {
1587							res = QP_TEST_RESULT_FAIL;
1588						}
1589						break;
1590					default:
1591						break;
1592					}
1593				}
1594			}
1595			else
1596			{
1597				deUint32 ik, kj, ij;
1598				for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1599				{
1600					for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1601					{
1602						for (deUint32 i = 0; i < M; ++i)
1603						{
1604							for (deUint32 j = 0; j < N; ++j)
1605							{
1606								deUint32 ref = 0;
1607
1608								for (deUint32 k = 0; k < K; ++k)
1609								{
1610									if (m_data.colMajor)
1611										ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1612									else
1613										ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1614
1615									deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);
1616
1617									if (m_data.colMajor)
1618										kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1619									else
1620										kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1621
1622									deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);
1623
1624									ref += Aik*Bkj;
1625								}
1626
1627								if (m_data.colMajor)
1628									ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1629								else
1630									ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1631
1632								deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);
1633
1634								if (saturated)
1635								{
1636									ref = satAddData(dataTypes[2], ref, Cij);
1637								}
1638								else
1639								{
1640									ref += Cij;
1641									// truncate the result to the size of C's type.
1642									deUint32 bits = componentTypeInfo[dataTypes[3]].bits;
1643									deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1644									ref &= mask;
1645								}
1646
1647								// When loading with stride 0, ij for matrix D is different from matrix C
1648								if (m_data.colMajor)
1649									ij = mX * M + i + strides[2] * (mY * N + j);
1650								else
1651									ij = mX * N + j + strides[2] * (mY * M + i);
1652
1653								deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);
1654
1655								if (ref != Dij)
1656								{
1657									res = QP_TEST_RESULT_FAIL;
1658								}
1659							}
1660						}
1661					}
1662				}
1663			}
1664		}
1665
1666		if (res != QP_TEST_RESULT_PASS)
1667		{
1668			finalres = res;
1669
1670			log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
1671
1672#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
1673			for (int i = 0; i < 4; i++)
1674			{
1675				const char* matrixNames[] = { "A", "B", "C", "D" };
1676
1677				log << tcu::TestLog::Message
1678					<< "Matrix " << matrixNames[i]
1679					<< "[rows="
1680					<< m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
1681					<< ", cols="
1682					<< m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
1683					<< dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
1684					<< tcu::TestLog::EndMessage;
1685			}
1686#endif
1687		}
1688		else
1689		{
1690			if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
1691				finalres = res;
1692		}
1693	}
1694
1695	return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
1696}
1697
1698const char* getUseType (UseType useType)
1699{
1700	switch (useType)
1701	{
1702		case UT_NV:			return "nv";
1703		case UT_KHR_A:		return "khr_a";
1704		case UT_KHR_B:		return "khr_b";
1705		case UT_KHR_Result:	return "khr_r";
1706		default:			TCU_THROW(InternalError, "Unknown use type");
1707	}
1708}
1709
1710tcu::TestCaseGroup*	createCooperativeMatrixTestsInternal (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
1711{
1712	de::MovePtr<tcu::TestCaseGroup> group	(new tcu::TestCaseGroup(testCtx, getUseType(useType)));
1713
1714	typedef struct
1715	{
1716		deUint32				value;
1717		const char*				name;
1718	} TestGroupCase;
1719
1720	typedef struct
1721	{
1722		deUint32				value[2];
1723		const char*				name;
1724	} TestGroupCase2;
1725
1726	typedef struct
1727	{
1728		SubgroupSizeMode		value;
1729		const char*				name;
1730	} SubGroubSizes;
1731
1732	TestGroupCase ttCases[] =
1733	{
1734		// OpCooperativeMatrixLength
1735		{ TT_LENGTH,				"length"},
1736		// OpConstantComposite
1737		{ TT_CONSTANT,				"constant"},
1738		// OpCompositeConstruct
1739		{ TT_COMPOSITE,				"composite"},
1740		// OpCompositeExtract
1741		{ TT_COMPOSITE_RVALUE,		"composite_rvalue"},
1742		// OpFAdd/OpIAdd
1743		{ TT_ADD,					"add"},
1744		// OpFSub/OpISub
1745		{ TT_SUB,					"sub"},
1746		// OpFDiv/OpSDiv/OpUDiv
1747		{ TT_DIV,					"div"},
1748		// OpFMul/OpIMul
1749		{ TT_MUL,					"mul"},
1750		// OpFNegate/OpSNegate
1751		{ TT_NEGATE,				"negate"},
1752		// OpMatrixTimesScalar
1753		{ TT_MATRIXTIMESSCALAR,		"matrixtimesscalar"},
1754		// OpFunctionParameter
1755		{ TT_FUNC,					"func"},
1756		// OpCooperativeMatrixMulAdd
1757		{ TT_MATRIXMULADD,			"matrixmuladd"},
1758		// OpCompositeConstruct w/array
1759		{ TT_COMPOSITE_ARRAY,		"composite_array"},
1760		// OpCooperativeMatrixMulAdd w/array
1761		{ TT_MATRIXMULADD_ARRAY,	"matrixmuladd_array"},
1762		// OpCooperativeMatrixMulAdd w/saturations
1763		{ TT_MATRIXMULADD_SATURATED,"matrixmuladd_saturated"},
1764		// OpCooperativeMatrixMulAdd w/wrapping
1765		{ TT_MATRIXMULADD_WRAPPING,	"matrixmuladd_wrapping"},
1766		// OpCooperativeMatrixMulAdd w/stride==0
1767		{ TT_MATRIXMULADD_STRIDE0,	"matrixmuladd_stride0"},
1768	};
1769	TestGroupCase2 dtCases[] =
1770	{
1771		// A/B are fp32 C/D are fp32
1772		{ { VK_COMPONENT_TYPE_FLOAT32_KHR,	VK_COMPONENT_TYPE_FLOAT32_KHR },	"float32_float32"},
1773		// A/B are fp32 C/D are fp16
1774		{ { VK_COMPONENT_TYPE_FLOAT32_KHR,	VK_COMPONENT_TYPE_FLOAT16_KHR },	"float32_float16"},
1775		// A/B are fp16 C/D are fp32
1776		{ { VK_COMPONENT_TYPE_FLOAT16_KHR,	VK_COMPONENT_TYPE_FLOAT32_KHR },	"float16_float32"},
1777		// A/B are fp16 C/D are fp16
1778		{ { VK_COMPONENT_TYPE_FLOAT16_KHR,	VK_COMPONENT_TYPE_FLOAT16_KHR },	"float16_float16"},
1779		// A/B are u8 C/D are u8
1780		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_UINT8_KHR },		"uint8_uint8"},
1781		// A/B are u8 C/D are u32
1782		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_UINT32_KHR },		"uint8_uint32"},
1783		// A/B are s8 C/D are s8
1784		{ { VK_COMPONENT_TYPE_SINT8_KHR,	VK_COMPONENT_TYPE_SINT8_KHR },		"sint8_sint8"},
1785		// A/B are s8 C/D are s32
1786		{ { VK_COMPONENT_TYPE_SINT8_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"sint8_sint32"},
1787		// A/B are u8 C/D are s32
1788		{ { VK_COMPONENT_TYPE_UINT8_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"uint8_sint32"},
1789		// A/B are u32 C/D are u32
1790		{ { VK_COMPONENT_TYPE_UINT32_KHR,	VK_COMPONENT_TYPE_UINT32_KHR },		"uint32_uint32"},
1791		// A/B are u32 C/D are u8
1792		{ { VK_COMPONENT_TYPE_UINT32_KHR,	VK_COMPONENT_TYPE_UINT8_KHR },		"uint32_uint8"},
1793		// A/B are s32 C/D are s32
1794		{ { VK_COMPONENT_TYPE_SINT32_KHR,	VK_COMPONENT_TYPE_SINT32_KHR },		"sint32_sint32"},
1795		// A/B are s32 C/D are s8
1796		{ { VK_COMPONENT_TYPE_SINT32_KHR,	VK_COMPONENT_TYPE_SINT8_KHR },		"sint32_sint8"},
1797	};
1798	SubGroubSizes sgsCases[] =
1799	{
1800		// Default subgroup size
1801		{ SUBGROUP_SIZE_NONE,	"" },
1802		// Minimum subgroup size
1803		{ SUBGROUP_SIZE_MIN,	"_min"},
1804		// Maximum subgroup size
1805		{ SUBGROUP_SIZE_MAX,	"_max"},
1806	};
1807
1808	TestGroupCase colCases[] =
1809	{
1810		{ 0,		"rowmajor"},
1811		{ 1,		"colmajor"},
1812	};
1813
1814	TestGroupCase scCases[] =
1815	{
1816		// SSBO
1817		{ SC_BUFFER,						"buffer"},
1818		// shared memory
1819		{ SC_WORKGROUP,						"workgroup"},
1820		// SSBO w/variable pointers
1821		{ SC_BUFFER_VARIABLE_POINTERS,		"buffer_varptr"},
1822		// shared memory w/variable pointers
1823		{ SC_WORKGROUP_VARIABLE_POINTERS,	"workgroup_varptr"},
1824		// physical_storage_buffer
1825		{ SC_PHYSICAL_STORAGE_BUFFER,		"physical_buffer"},
1826	};
1827
1828	// Types tested for conversions. Excludes 64b types.
1829	VkComponentTypeKHR allTypes[] =
1830	{
1831		VK_COMPONENT_TYPE_FLOAT16_KHR,
1832		VK_COMPONENT_TYPE_FLOAT32_KHR,
1833		VK_COMPONENT_TYPE_SINT8_KHR,
1834		VK_COMPONENT_TYPE_SINT16_KHR,
1835		VK_COMPONENT_TYPE_SINT32_KHR,
1836		VK_COMPONENT_TYPE_UINT8_KHR,
1837		VK_COMPONENT_TYPE_UINT16_KHR,
1838		VK_COMPONENT_TYPE_UINT32_KHR,
1839	};
1840
1841	for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1842	{
1843		const TestType	testType = (TestType)ttCases[ttNdx].value;
1844
1845		for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
1846		{
1847			if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
1848				continue;
1849
1850			if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
1851				continue;
1852
1853			const string					name	= string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
1854			de::MovePtr<tcu::TestCaseGroup>	ttGroup	(new tcu::TestCaseGroup(testCtx, name.c_str()));
1855
1856			for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
1857			{
1858				de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
1859				for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
1860				{
1861					de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
1862					for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
1863					{
1864						const VkComponentTypeKHR	inputType = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
1865						const VkComponentTypeKHR	outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
1866						const bool					isMatrixMul = isMatrixMulAddOp(testType);
1867
1868						// useType isn't used for matrixmul shaders. Don't generate 3 copies of those tests.
1869						if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B)) {
1870							continue;
1871						}
1872
1873						// NV extension doesn't support mixing signedness
1874						if (isMatrixMul && (useType == UT_NV) && isSIntType(inputType) != isSIntType(outputType)) {
1875							continue;
1876						}
1877
1878						if (!isMatrixMul && inputType != outputType)
1879							continue;
1880
1881						if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
1882							continue;
1883
1884						if (testType == TT_MUL && useType == UT_NV)
1885							continue;
1886
1887						if (testType == TT_MATRIXMULADD_SATURATED && (isFloatType(inputType) || useType == UT_NV))
1888							continue;
1889
1890						if (testType == TT_MATRIXMULADD_WRAPPING && (isFloatType(inputType) || useType == UT_NV))
1891							continue;
1892
1893						if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
1894							continue;
1895
1896						if (testType == TT_LENGTH && useType != UT_NV && (outputType == VK_COMPONENT_TYPE_SINT8_KHR || outputType == VK_COMPONENT_TYPE_UINT8_KHR))
1897							continue;
1898
1899						CaseDef c =
1900						{
1901							testType,							//  TestType							testtype;
1902							2u,									//  deUint32							subgroupsPerWorkgroupX;
1903							2u,									//  deUint32							subgroupsPerWorkgroupY;
1904							4u,									//  deUint32							workgroupsX;
1905							4u,									//  deUint32							workgroupsY;
1906							inputType,							//  VkComponentTypeKHR					inputType;
1907							outputType,							//  VkComponentTypeKHR					outputType;
1908							!!colCases[colNdx].value,			//  bool								colMajor;
1909							(StorageClass)scCases[scNdx].value,	//  StorageClass						storageClass;
1910							useType,							//  UseType								useType;
1911							sgsCases[sgsNdx].value,				//  SubgroupSizeMode					subgroupSizeMode;
1912							computePipelineConstructionType,	//  vk::ComputePipelineConstructionType	computePipelineConstructionType;
1913						};
1914
1915						scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
1916					}
1917					dtGroup->addChild(scGroup.release());
1918				}
1919				ttGroup->addChild(dtGroup.release());
1920			}
1921			group->addChild(ttGroup.release());
1922		}
1923	}
1924
1925	{
1926		const string					name	= string("convert");
1927		const string					desc	= string("OpFConvert/OpSConvert/OpUConvert/OpBitcast");
1928		de::MovePtr<tcu::TestCaseGroup>	ttGroup	(new tcu::TestCaseGroup(testCtx, name.c_str()));
1929
1930		for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
1931		{
1932			for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
1933			{
1934				const VkComponentTypeKHR	inputType = (VkComponentTypeKHR)allTypes[dtNdx1];
1935				const VkComponentTypeKHR	outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
1936				const string			name2	= string("input_") + string(componentTypeInfo[inputType].typeName) + string("_output_") + string(componentTypeInfo[outputType].typeName);
1937				de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
1938				for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
1939				{
1940					de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
1941					for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
1942					{
1943
1944						CaseDef c =
1945						{
1946							TT_CONVERT,							//  TestType							testtype;
1947							2u,									//  deUint32							subgroupsPerWorkgroupX;
1948							2u,									//  deUint32							subgroupsPerWorkgroupY;
1949							4u,									//  deUint32							workgroupsX;
1950							4u,									//  deUint32							workgroupsY;
1951							inputType,							//  VkComponentTypeKHR					inputType;
1952							outputType,							//  VkComponentTypeKHR					outputType;
1953							!!colCases[colNdx].value,			//  bool								colMajor;
1954							(StorageClass)scCases[scNdx].value,	//  StorageClass						storageClass;
1955							useType,							//  UseType								useType;
1956							SUBGROUP_SIZE_NONE,					//  SubgroupSizeMode					subgroupSizeMode;
1957							computePipelineConstructionType,	//  vk::ComputePipelineConstructionType	computePipelineConstructionType;
1958						};
1959
1960						scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
1961					}
1962					dtGroup->addChild(scGroup.release());
1963				}
1964				ttGroup->addChild(dtGroup.release());
1965			}
1966		}
1967		group->addChild(ttGroup.release());
1968	}
1969
1970	return group.release();
1971}
1972
1973}	// anonymous
1974
1975tcu::TestCaseGroup* createCooperativeMatrixTests (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType)
1976{
1977	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));
1978
1979	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
1980	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
1981	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
1982	group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));
1983
1984	return group.release();
1985}
1986
1987}	// compute
1988}	// vkt
1989