1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2018-2019 NVIDIA Corporation
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Vulkan Cooperative Matrix tests
23 *//*--------------------------------------------------------------------*/
24
25 #include "vktComputeCooperativeMatrixTests.hpp"
26
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
34
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
37
38 #include "deDefs.h"
39 #include "deFloat16.h"
40 #include "deMath.h"
41 #include "deRandom.h"
42 #include "deSharedPtr.hpp"
43 #include "deString.h"
44
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
47
48 #include <string>
49 #include <sstream>
50 #include <set>
51 #include <algorithm>
52
53 namespace vkt
54 {
55 namespace compute
56 {
57 namespace
58 {
59 using namespace vk;
60 using namespace std;
61
// Operation exercised by a test case. Each value selects a different shader
// body generated in CooperativeMatrixTestCase::initPrograms().
typedef enum
{
	TT_LENGTH = 0,			// matO = outputType(matO.length())
	TT_CONSTANT,			// matO = matConst (constant-initialized matrix)
	TT_CONVERT,				// matO = outputType(matA): component type conversion
	TT_COMPOSITE,			// per-element access: matO[i] = matA[i] + matB[i]
	TT_COMPOSITE_RVALUE,	// component access on an r-value expression: (t += matB)[i]
	TT_ADD,					// matO = matA + matB
	TT_SUB,					// matO = matA - matB
	TT_DIV,					// matO = matA / matB
	TT_NEGATE,				// matO = -matA
	TT_MATRIXTIMESSCALAR,	// matO = (scalar * matA) * scalar
	TT_FUNC,				// matrix passed through/returned from a user-defined function
	TT_MATRIXMULADD,		// matO = coopMatMulAddNV(matA, matB, matC)
	TT_COMPOSITE_ARRAY,		// TT_COMPOSITE variant operating on arrays of matrices
	TT_MATRIXMULADD_ARRAY,	// TT_MATRIXMULADD variant operating on arrays of matrices
} TestType;
79
// Memory the matrices are loaded from / stored to in the generated shader.
typedef enum
{
	SC_BUFFER = 0,					// plain SSBOs at bindings 0..3
	SC_WORKGROUP,					// data staged through shared (workgroup) memory
	SC_WORKGROUP_VARIABLE_POINTERS,	// as SC_WORKGROUP, with variable pointers enabled
	SC_BUFFER_VARIABLE_POINTERS,	// as SC_BUFFER, with variable pointers enabled
	SC_PHYSICAL_STORAGE_BUFFER,		// buffers accessed via buffer_reference device addresses (binding 4 holds the addresses)
} StorageClass;
88
// These are compute-only tests, so descriptors are visible to the compute stage only.
const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
90
// Parameters fully describing one test case.
struct CaseDef
{
	TestType			testType;				// operation to exercise
	deUint32			subgroupsPerWorkgroupX;	// subgroup grid within a workgroup (X)
	deUint32			subgroupsPerWorkgroupY;	// subgroup grid within a workgroup (Y)
	deUint32			workgroupsX;			// dispatch size in workgroups (X)
	deUint32			workgroupsY;			// dispatch size in workgroups (Y)
	VkComponentTypeNV	inputType;				// component type of matrices A and B
	VkComponentTypeNV	outputType;				// component type of matrix C and the output
	bool				colMajor;				// column-major matrix layout when true
	StorageClass		storageClass;			// where the matrices live (see StorageClass)
};
103
// Per-run instance: iterate() performs the actual device work for one case.
class CooperativeMatrixTestInstance : public TestInstance
{
public:
						CooperativeMatrixTestInstance	(Context& context, const CaseDef& data);
						~CooperativeMatrixTestInstance	(void);
	tcu::TestStatus		iterate							(void);
private:
	CaseDef				m_data;	// immutable case parameters, copied at construction
};
113
// Constructor: just stores the case parameters for use by iterate().
CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
	: vkt::TestInstance		(context)
	, m_data				(data)
{
}
119
// Destructor: no resources beyond what the base class manages.
CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
{
}
123
// Test case node: generates the shader (initPrograms), verifies device
// support (checkSupport) and creates the runtime instance (createInstance).
class CooperativeMatrixTestCase : public TestCase
{
public:
							CooperativeMatrixTestCase	(tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
							~CooperativeMatrixTestCase	(void);
	virtual	void			initPrograms				(SourceCollections& programCollection) const;
	virtual TestInstance*	createInstance				(Context& context) const;
	virtual void			checkSupport				(Context& context) const;

private:
	CaseDef					m_data;	// immutable case parameters, copied at construction
};
136
// Constructor: just stores the case parameters.
CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
	: vkt::TestCase	(context, name, desc)
	, m_data		(data)
{
}
142
// Destructor: no resources beyond what the base class manages.
CooperativeMatrixTestCase::~CooperativeMatrixTestCase (void)
{
}
146
checkSupport(Context& context) const147 void CooperativeMatrixTestCase::checkSupport(Context& context) const
148 {
149 if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
150 {
151 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
152 }
153
154 if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
155 {
156 TCU_THROW(NotSupportedError, "cooperativeMatrix not supported");
157 }
158
159 if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
160 {
161 TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
162 }
163
164 if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
165 !context.getVariablePointersFeatures().variablePointers)
166 {
167 TCU_THROW(NotSupportedError, "variable pointers not supported");
168 }
169
170 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
171 {
172 TCU_THROW(NotSupportedError, "buffer device address not supported");
173 }
174
175 if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
176 (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_NV))
177 {
178 TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
179 }
180
181 deUint32 propertyCount = 0;
182 VkCooperativeMatrixPropertiesNV *pProperties;
183 context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, DE_NULL);
184 if (propertyCount == 0)
185 TCU_THROW(NotSupportedError, "cooperative matrices not supported");
186
187 bool supported[2] = { false, false };
188 pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];
189
190 for (deUint32 i = 0; i < propertyCount; ++i)
191 {
192 VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
193 p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
194 p->pNext = DE_NULL;
195 }
196
197 context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, pProperties);
198
199 for (deUint32 i = 0; i < propertyCount; ++i)
200 {
201 VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
202 if (m_data.testType == TT_MATRIXMULADD ||
203 m_data.testType == TT_MATRIXMULADD_ARRAY)
204 {
205 if (p->AType == m_data.inputType &&
206 p->BType == m_data.inputType &&
207 p->CType == m_data.outputType &&
208 p->DType == m_data.outputType &&
209 p->scope == VK_SCOPE_SUBGROUP_NV)
210 {
211 supported[0] = supported[1] = true;
212 }
213 }
214 else
215 {
216 VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };
217
218 for (deUint32 j = 0; j < 2; ++j)
219 {
220 if (p->scope == VK_SCOPE_SUBGROUP_NV && (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->DType == types[j]))
221 {
222 supported[j] = true;
223 }
224 }
225 }
226 }
227
228 delete [] pProperties;
229
230 if (!supported[0] || !supported[1])
231 TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
232 }
233
// GLSL type information per component type. The table is indexed directly by
// VkComponentTypeNV value (see componentTypeInfo[m_data.inputType] uses
// below), so the entry order must match the enum's value order.
struct {
	const char *typeName;			// GLSL scalar type name
	const char *coopmatTypeName;	// matching cooperative matrix type name
	deUint32 bits;					// component width in bits
} componentTypeInfo[] =
{
	{ "float16_t",	"fcoopmatNV",	16 },
	{ "float32_t",	"fcoopmatNV",	32 },
	{ "float64_t",	"fcoopmatNV",	64 },
	{ "int8_t",		"icoopmatNV",	8 },
	{ "int16_t",	"icoopmatNV",	16 },
	{ "int32_t",	"icoopmatNV",	32 },
	{ "int64_t",	"icoopmatNV",	64 },
	{ "uint8_t",	"ucoopmatNV",	8 },
	{ "uint16_t",	"ucoopmatNV",	16 },
	{ "uint32_t",	"ucoopmatNV",	32 },
	{ "uint64_t",	"ucoopmatNV",	64 },
};
252
isFloatType(VkComponentTypeNV t)253 static bool isFloatType(VkComponentTypeNV t)
254 {
255 switch (t)
256 {
257 default:
258 return false;
259 case VK_COMPONENT_TYPE_FLOAT16_NV:
260 case VK_COMPONENT_TYPE_FLOAT32_NV:
261 case VK_COMPONENT_TYPE_FLOAT64_NV:
262 return true;
263 }
264 }
265
isSIntType(VkComponentTypeNV t)266 static bool isSIntType(VkComponentTypeNV t)
267 {
268 switch (t)
269 {
270 default:
271 return false;
272 case VK_COMPONENT_TYPE_SINT8_NV:
273 case VK_COMPONENT_TYPE_SINT16_NV:
274 case VK_COMPONENT_TYPE_SINT32_NV:
275 case VK_COMPONENT_TYPE_SINT64_NV:
276 return true;
277 }
278 }
279
// Generates the GLSL compute shader for this case and registers it as "test".
// Matrix dimensions (M/N/K), buffer strides and the workgroup size are all
// supplied later through specialization constants (constant_id 0..8), so one
// shader serves every supported size. The shader loads A/B/C (optionally
// staged through shared memory or reached via buffer device addresses),
// performs the operation selected by m_data.testType, and stores the result.
void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
{
	std::stringstream css;
	css << "#version 450 core\n";
	css << "#pragma use_vulkan_memory_model\n";
	css <<
	"#extension GL_KHR_shader_subgroup_basic : enable\n"
	"#extension GL_KHR_memory_scope_semantics : enable\n"
	"#extension GL_NV_cooperative_matrix : enable\n"
	"#extension GL_NV_integer_cooperative_matrix : enable\n"
	"#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable\n"
	"#extension GL_EXT_shader_explicit_arithmetic_types_float32 : enable\n"
	"#extension GL_EXT_shader_explicit_arithmetic_types_int8 : enable\n"
	"#extension GL_EXT_shader_explicit_arithmetic_types_int32 : enable\n"
	"#extension GL_EXT_buffer_reference : enable\n"
	"// strides overriden by spec constants\n"
	"layout(constant_id = 2) const int AStride = 1;\n"
	"layout(constant_id = 3) const int BStride = 1;\n"
	"layout(constant_id = 4) const int CStride = 1;\n"
	"layout(constant_id = 5) const int OStride = 1;\n"
	"layout(constant_id = 6) const int M = 1;\n"
	"layout(constant_id = 7) const int N = 1;\n"
	"layout(constant_id = 8) const int K = 1;\n"
	"layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";

	if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
		css << "#pragma use_variable_pointers\n";

	// Dimensions (as spec-constant names) of matrices A, B, C and the output.
	struct
	{
		string rows, cols;
	} dims[4];

	// For multiply-add, D = A (MxK) * B (KxN) + C (MxN). All other tests use
	// same-size MxN matrices throughout.
	if (m_data.testType == TT_MATRIXMULADD ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		dims[0].rows = "M";
		dims[0].cols = "K";
		dims[1].rows = "K";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}
	else
	{
		dims[0].rows = "M";
		dims[0].cols = "N";
		dims[1].rows = "M";
		dims[1].cols = "N";
		dims[2].rows = "M";
		dims[2].cols = "N";
		dims[3].rows = "M";
		dims[3].cols = "N";
	}

	// A/B use the input component type, C and the output use the output type.
	const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
	const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
	const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;

	css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
	css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		// Buffers are reached through 64-bit device addresses stored in a
		// small params SSBO at binding 4.
		css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
		css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
		css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
		css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
	}
	else
	{
		css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
		css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
		css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
		css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
	}

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		// One tile per subgroup in the workgroup, hence the subgroupsPerWG scale.
		css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
		css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
	}

	// Cooperative matrix type spellings, e.g. fcoopmatNV<16, gl_ScopeSubgroup, M, K>.
	std::stringstream matAType, matBType, matCType, outputMatType;

	matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
	matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
	matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
	outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";

	// The matrices are deliberately declared at global scope.
	css << matAType.str() << " matA;\n";
	css << matBType.str() << " matB;\n";
	css << matCType.str() << " matC;\n";
	css << outputMatType.str() << " matO;\n";

	if (m_data.testType == TT_CONSTANT)
		css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

	if (m_data.testType == TT_FUNC)
		css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

	css <<
	"void main()\n"
	"{\n"
	// matrixID is the x,y index of the matrix owned by this subgroup.
	" uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
	" uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";

	if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
	{
		// Pull the buffer references out of the params SSBO into locals so
		// the rest of the shader is storage-class agnostic.
		css << " InputA inputA = params.inputA;\n";
		css << " InputB inputB = params.inputB;\n";
		css << " InputC inputC = params.inputC;\n";
		css << " Output outputO = params.outputO;\n";
	}

	// Row pitch (in elements) of each matrix in buffer memory: the leading
	// dimension times the number of matrices laid out side by side in X.
	string strides[4];
	for (deUint32 i = 0; i < 4; ++i)
	{
		strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
	}

	// element<i> is the starting element in buffer memory.
	// elementS<i> is the starting element in shared memory.
	css << " uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
		   " uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
		   " uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
		   " uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
		   " uint elementS0, elementS1, elementS2, elementS3;\n";

	// For shared memory tests, copy the matrix from buffer memory into
	// workgroup memory. For simplicity, do it all on a single thread.
	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		const char *name[] =
		{
			"sharedA",
			"sharedB",
			"sharedC",
		};
		const char *inputName[] =
		{
			"inputA",
			"inputB",
			"inputC",
		};
		// Starting offsets within shared memory; the shared stride only covers
		// one workgroup's worth of tiles, hence the division by workgroupsX.
		for (deUint32 m = 0; m < 4; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css << " elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
		}
		css << " if (subgroupElect()) {\n";
		// copy all three input buffers.
		for (deUint32 m = 0; m < 3; ++m)
		{
			string sharedStride = strides[m] + " / workgroupsX";
			css << " for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
				   " for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
				   " int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
				   " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
				   " " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
				   " }\n"
				   " }\n";
			// From here on this matrix is addressed in shared memory, so
			// switch its stride to the shared-memory pitch.
			strides[m] = sharedStride;
		}
		css << " }\n";
		// Make the staged data visible to the rest of the subgroup.
		css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
	}

	const char *colMajor = (m_data.colMajor ? "true" : "false");

	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		css << " coopMatLoadNV(matA, sharedA, elementS0, " << strides[0] << ", " << colMajor << ");\n"
			   " coopMatLoadNV(matB, sharedB, elementS1, " << strides[1] << ", " << colMajor << ");\n"
			   " coopMatLoadNV(matC, sharedC, elementS2, " << strides[2] << ", " << colMajor << ");\n";
	}
	else
	{
		css << " coopMatLoadNV(matA, inputA.x, element0, " << strides[0] << ", " << colMajor << ");\n"
			   " coopMatLoadNV(matB, inputB.x, element1, " << strides[1] << ", " << colMajor << ");\n"
			   " coopMatLoadNV(matC, inputC.x, element2, " << strides[2] << ", " << colMajor << ");\n";
	}

	// Array tests operate on element [1] of two-element arrays; element [0]
	// is zero-filled so an indexing bug would corrupt the result.
	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << " " << matAType.str() << " matAArr[2];\n matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
			   " " << matBType.str() << " matBArr[2];\n matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
			   " " << matCType.str() << " matCArr[2];\n matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
			   " " << outputMatType.str() << " matOArr[2];\n";
	}

	// Emit the operation under test (see TestType).
	switch (m_data.testType)
	{
	default:
		DE_ASSERT(0);
		// fall through
	case TT_LENGTH:
		css << " matO = " << outputMatType.str() << "(matO.length());\n";
		break;
	case TT_CONSTANT:
		css << " matO = matConst;\n";
		break;
	case TT_CONVERT:
		css << " matO = " << outputMatType.str() << "(matA);\n";
		break;
	case TT_COMPOSITE:
	case TT_COMPOSITE_RVALUE:
		css << " for (int i = 0; i < matA.length(); ++i) {\n"
			   " matO[i] = matA[i] + matB[i];\n"
			   " }\n";
		if (m_data.testType == TT_COMPOSITE_RVALUE)
		{
			css << " " << matAType.str() << " t = matA;\n"
				   " matO[0] = (t += matB)[0];\n"
				   " if (matA.length() > 0) {\n"
				   " t = matA;\n"
				   " matO[1] = (t += matB)[1];\n"
				   " }\n";
		}
		break;
	case TT_COMPOSITE_ARRAY:
		css << " for (int i = 0; i < matA.length(); ++i) {\n"
			   " matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
			   " }\n";
		break;
	case TT_ADD:
		css << " matO = matA + matB;\n";
		break;
	case TT_SUB:
		css << " matO = matA - matB;\n";
		break;
	case TT_DIV:
		css << " matO = matA / matB;\n";
		break;
	case TT_NEGATE:
		css << " matO = -matA;\n";
		break;
	case TT_FUNC:
		css << " matO = f(matA);\n";
		break;
	case TT_MATRIXTIMESSCALAR:
		css << " matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
		break;
	case TT_MATRIXMULADD:
		css << " matO = coopMatMulAddNV(matA, matB, matC);\n";
		break;
	case TT_MATRIXMULADD_ARRAY:
		css << " matOArr[1] = coopMatMulAddNV(matAArr[1], matBArr[1], matCArr[1]);\n";
		break;
	}

	if (m_data.testType == TT_COMPOSITE_ARRAY ||
		m_data.testType == TT_MATRIXMULADD_ARRAY)
	{
		css << " matOArr[0] = " << outputMatType.str() << "(0.0);\n";
		css << " matO = matOArr[1];\n";
	}

	// Store the result: via shared memory (mirroring the load staging) or
	// directly to the output buffer.
	if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
	{
		string sharedStride = strides[3] + " / workgroupsX";
		css << " coopMatStoreNV(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
		css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
		css << " if (subgroupElect()) {\n";
		css << " for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
			   " for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
			   " int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
			   " outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
			   " }\n"
			   " }\n";
		css << " }\n";
	}
	else
	{
		css << " coopMatStoreNV(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
	}

	css <<
	"}\n";

	// SPIR-V 1.3 is required for the subgroup and memory-scope functionality.
	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);

	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
}
573
// Framework factory hook: creates the per-run instance for this case.
TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
{
	return new CooperativeMatrixTestInstance(context, m_data);
}
578
setDataFloat(void *base, VkComponentTypeNV dt, deUint32 i, float value)579 static void setDataFloat(void *base, VkComponentTypeNV dt, deUint32 i, float value)
580 {
581 if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
582 {
583 ((float *)base)[i] = value;
584 }
585 else
586 {
587 DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
588 ((deFloat16 *)base)[i] = deFloat32To16(value);
589 }
590 }
591
getDataFloat(void *base, VkComponentTypeNV dt, deUint32 i)592 static float getDataFloat(void *base, VkComponentTypeNV dt, deUint32 i)
593 {
594 if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
595 {
596 return ((float *)base)[i];
597 }
598 else
599 {
600 DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
601 return deFloat16To32(((deFloat16 *)base)[i]);
602 }
603 }
604
setDataInt(void *base, VkComponentTypeNV dt, deUint32 i, deUint32 value)605 static void setDataInt(void *base, VkComponentTypeNV dt, deUint32 i, deUint32 value)
606 {
607 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
608 switch (dt) {
609 default: DE_ASSERT(0); // fallthrough
610 case VK_COMPONENT_TYPE_UINT8_NV: ((deUint8 *)base)[i] = (deUint8)value; break;
611 case VK_COMPONENT_TYPE_UINT16_NV: ((deUint16 *)base)[i] = (deUint16)value; break;
612 case VK_COMPONENT_TYPE_UINT32_NV: ((deUint32 *)base)[i] = (deUint32)value; break;
613 case VK_COMPONENT_TYPE_SINT8_NV: ((deInt8 *)base)[i] = (deInt8)value; break;
614 case VK_COMPONENT_TYPE_SINT16_NV: ((deInt16 *)base)[i] = (deInt16)value; break;
615 case VK_COMPONENT_TYPE_SINT32_NV: ((deInt32 *)base)[i] = (deInt32)value; break;
616 }
617 }
618
getDataInt(void *base, VkComponentTypeNV dt, deUint32 i)619 static deUint32 getDataInt(void *base, VkComponentTypeNV dt, deUint32 i)
620 {
621 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
622 switch (dt) {
623 default: DE_ASSERT(0); // fallthrough
624 case VK_COMPONENT_TYPE_UINT8_NV: return ((deUint8 *)base)[i];
625 case VK_COMPONENT_TYPE_UINT16_NV: return ((deUint16 *)base)[i];
626 case VK_COMPONENT_TYPE_UINT32_NV: return ((deUint32 *)base)[i];
627 case VK_COMPONENT_TYPE_SINT8_NV: return ((deInt8 *)base)[i];
628 case VK_COMPONENT_TYPE_SINT16_NV: return ((deInt16 *)base)[i];
629 case VK_COMPONENT_TYPE_SINT32_NV: return ((deInt32 *)base)[i];
630 }
631 }
632
iterate(void)633 tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
634 {
635 const DeviceInterface& vk = m_context.getDeviceInterface();
636 const VkDevice device = m_context.getDevice();
637 Allocator& allocator = m_context.getDefaultAllocator();
638 MemoryRequirement memoryDeviceAddress = m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
639 m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
640 qpTestResult finalres = QP_TEST_RESULT_PASS;
641 tcu::TestLog& log = m_context.getTestContext().getLog();
642
643 deRandom rnd;
644 deRandom_init(&rnd, 1234);
645
646 vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
647 deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
648 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
649
650 vk::VkPhysicalDeviceProperties2 properties2;
651 deMemset(&properties2, 0, sizeof(properties2));
652 properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
653 properties2.pNext = &subgroupProperties;
654
655 m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);
656
657 deUint32 propertyCount = 0;
658 VkCooperativeMatrixPropertiesNV *pProperties;
659 m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, DE_NULL);
660 // Shouldn't have made it through checkSupport without any properties
661 DE_ASSERT(propertyCount != 0);
662
663 pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];
664
665 for (deUint32 i = 0; i < propertyCount; ++i)
666 {
667 VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
668 p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
669 p->pNext = DE_NULL;
670 }
671
672 m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, pProperties);
673
674 struct TestTuple
675 {
676 TestTuple() {}
677 TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}
678
679 bool operator<(const TestTuple &other) const
680 {
681 return M < other.M ||
682 (M == other.M && N < other.N) ||
683 (M == other.M && N == other.N && K < other.K);
684 }
685
686 deUint32 M, N, K;
687 };
688
689 vector<TestTuple> testSizes;
690
691 if (m_data.testType == TT_MATRIXMULADD ||
692 m_data.testType == TT_MATRIXMULADD_ARRAY)
693 {
694 for (deUint32 i = 0; i < propertyCount; ++i)
695 {
696 VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
697
698 if (p->AType == m_data.inputType &&
699 p->BType == m_data.inputType &&
700 p->CType == m_data.outputType &&
701 p->DType == m_data.outputType &&
702 p->scope == VK_SCOPE_SUBGROUP_NV)
703 {
704 testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
705 }
706 }
707 }
708 else
709 {
710 set<TestTuple> typeSizes[2];
711 VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };
712
713 for (deUint32 i = 0; i < propertyCount; ++i)
714 {
715 VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
716
717 if (p->scope != VK_SCOPE_SUBGROUP_NV)
718 continue;
719
720 for (deUint32 j = 0; j < 2; ++j)
721 {
722 // For these tests, m_data.M/N are always the matrix size. Check if they match
723 // any input or output in the list.
724 if (p->AType == types[j])
725 typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
726 if (p->BType == types[j])
727 typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
728 if (p->CType == types[j] ||
729 p->DType == types[j])
730 typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
731 }
732 }
733 // Test those sizes that are supported for both the input and output type.
734 std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
735 typeSizes[1].begin(), typeSizes[1].end(),
736 std::back_inserter(testSizes));
737 }
738
739 delete [] pProperties;
740
741 for (unsigned int s = 0; s < testSizes.size(); ++s)
742 {
743 // When testing a multiply, MxNxK is the type of matrix multiply.
744 // Otherwise, MxN is the size of the input/output matrices
745 deUint32 M, N, K;
746 M = testSizes[s].M;
747 N = testSizes[s].N;
748 K = testSizes[s].K;
749
750 log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
751
752 struct
753 {
754 deUint32 rows, cols;
755 } dims[4];
756
757 if (m_data.testType == TT_MATRIXMULADD ||
758 m_data.testType == TT_MATRIXMULADD_ARRAY)
759 {
760 dims[0].rows = M;
761 dims[0].cols = K;
762 dims[1].rows = K;
763 dims[1].cols = N;
764 dims[2].rows = M;
765 dims[2].cols = N;
766 dims[3].rows = M;
767 dims[3].cols = N;
768 }
769 else
770 {
771 dims[0].rows = M;
772 dims[0].cols = N;
773 dims[1].rows = M;
774 dims[1].cols = N;
775 dims[2].rows = M;
776 dims[2].cols = N;
777 dims[3].rows = M;
778 dims[3].cols = N;
779 }
780
781 VkComponentTypeNV dataTypes[4];
782 size_t elementSize[4];
783 VkDeviceSize bufferSizes[5];
784 de::MovePtr<BufferWithMemory> buffers[5];
785 vk::VkDescriptorBufferInfo bufferDescriptors[5];
786 deUint32 strides[4]; // in elements
787 deUint32 totalElements[4];
788
789 for (deUint32 i = 0; i < 5; ++i)
790 {
791 if (i < 4)
792 {
793 // A/B use input type, C/D use output type
794 dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
795 elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;
796
797 strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
798 totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;
799
800 bufferSizes[i] = totalElements[i] * elementSize[i];
801 }
802 else
803 {
804 bufferSizes[4] = sizeof(VkDeviceAddress)*4;
805 }
806
807 try
808 {
809 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
810 vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
811 (memoryDeviceAddress == MemoryRequirement::DeviceAddress ? VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
812 MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
813 }
814 catch (const tcu::NotSupportedError&)
815 {
816 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
817 vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
818 (memoryDeviceAddress == MemoryRequirement::DeviceAddress ? VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
819 MemoryRequirement::HostVisible | memoryDeviceAddress));
820 }
821
822 bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
823 }
824
825 void *ptrs[5];
826 for (deUint32 i = 0; i < 5; ++i)
827 {
828 ptrs[i] = buffers[i]->getAllocation().getHostPtr();
829 }
830
831 vk::DescriptorSetLayoutBuilder layoutBuilder;
832
833 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
834 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
835 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
836 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
837 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
838
839 vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));
840
841 vk::Unique<vk::VkDescriptorPool> descriptorPool(vk::DescriptorPoolBuilder()
842 .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
843 .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
844 vk::Unique<vk::VkDescriptorSet> descriptorSet (makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
845
846 vk::DescriptorSetUpdateBuilder setUpdateBuilder;
847 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
848 {
849 VkBufferDeviceAddressInfo info
850 {
851 VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType sType;
852 DE_NULL, // const void* pNext;
853 0, // VkBuffer buffer
854 };
855 VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
856 for (deUint32 i = 0; i < 4; ++i)
857 {
858 info.buffer = **buffers[i];
859 VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
860 addrsInMemory[i] = addr;
861 }
862 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
863 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
864 }
865 else
866 {
867 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
868 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
869 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
870 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
871 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
872 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
873 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
874 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
875 }
876
877 setUpdateBuilder.update(vk, device);
878
879 const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
880 {
881 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
882 DE_NULL, // pNext
883 (VkPipelineLayoutCreateFlags)0,
884 1, // setLayoutCount
885 &descriptorSetLayout.get(), // pSetLayouts
886 0u, // pushConstantRangeCount
887 DE_NULL, // pPushConstantRanges
888 };
889
890 Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);
891
892 Move<VkPipeline> pipeline;
893
894 VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
895
896 const deUint32 specData[9] =
897 {
898 subgroupProperties.subgroupSize * m_data.subgroupsPerWorkgroupX,
899 m_data.subgroupsPerWorkgroupY,
900 strides[0],
901 strides[1],
902 strides[2],
903 strides[3],
904 M,
905 N,
906 K,
907 };
908
909 const vk::VkSpecializationMapEntry entries[9] =
910 {
911 {0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
912 {1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
913 {2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
914 {3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
915 {4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
916 {5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
917 {6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
918 {7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
919 {8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
920 };
921
922 const vk::VkSpecializationInfo specInfo =
923 {
924 9, // mapEntryCount
925 entries, // pMapEntries
926 sizeof(specData), // dataSize
927 specData // pData
928 };
929
930 for (deUint32 i = 0; i < 4; ++i)
931 for (deUint32 j = 0; j < totalElements[i]; ++j)
932 {
933 if (isFloatType(dataTypes[i]))
934 {
935 if (m_data.testType != TT_MATRIXMULADD &&
936 m_data.testType != TT_MATRIXMULADD_ARRAY)
937 setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
938 else
939 setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
940 }
941 else
942 setDataInt(ptrs[i], dataTypes[i], j, (deRandom_getUint32(&rnd) & 0xff) - 128);
943 }
944
945 flushAlloc(vk, device, buffers[0]->getAllocation());
946 flushAlloc(vk, device, buffers[1]->getAllocation());
947 flushAlloc(vk, device, buffers[2]->getAllocation());
948 flushAlloc(vk, device, buffers[3]->getAllocation());
949
950 const Unique<VkShaderModule> shader (createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
951
952 const VkPipelineShaderStageCreateInfo shaderCreateInfo =
953 {
954 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
955 DE_NULL,
956 (VkPipelineShaderStageCreateFlags)0,
957 VK_SHADER_STAGE_COMPUTE_BIT, // stage
958 *shader, // shader
959 "main",
960 &specInfo, // pSpecializationInfo
961 };
962
963 const VkComputePipelineCreateInfo pipelineCreateInfo =
964 {
965 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
966 DE_NULL,
967 0u, // flags
968 shaderCreateInfo, // cs
969 *pipelineLayout, // layout
970 (vk::VkPipeline)0, // basePipelineHandle
971 0u, // basePipelineIndex
972 };
973 pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);
974
975 const VkQueue queue = m_context.getUniversalQueue();
976 Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
977 Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
978
979 beginCommandBuffer(vk, *cmdBuffer, 0u);
980
981 vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
982 vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
983
984 vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);
985
986 endCommandBuffer(vk, *cmdBuffer);
987
988 submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
989
990 invalidateAlloc(vk, device, buffers[3]->getAllocation());
991
992 qpTestResult res = QP_TEST_RESULT_PASS;
993
994 if (isFloatType(dataTypes[0]))
995 {
996 if (m_data.testType != TT_MATRIXMULADD &&
997 m_data.testType != TT_MATRIXMULADD_ARRAY)
998 {
999 for (deUint32 i = 0; i < totalElements[3]; ++i)
1000 {
1001 float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
1002 float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
1003 float output = getDataFloat(ptrs[3], dataTypes[3], i);
1004 switch (m_data.testType)
1005 {
1006 case TT_LENGTH:
1007 if (output < 1.0f || output > (float)(N*M))
1008 res = QP_TEST_RESULT_FAIL;
1009 // We expect the matrix to be spread evenly across invocations, it is
1010 // surprising (but not necessarily illegal) if not
1011 if (output != (float)(N*M/subgroupProperties.subgroupSize) &&
1012 res == QP_TEST_RESULT_PASS)
1013 res = QP_TEST_RESULT_QUALITY_WARNING;
1014 break;
1015 case TT_CONSTANT:
1016 if (output != 1.0f)
1017 res = QP_TEST_RESULT_FAIL;
1018 break;
1019 case TT_CONVERT:
1020 if (output != inputA)
1021 res = QP_TEST_RESULT_FAIL;
1022 break;
1023 case TT_COMPOSITE:
1024 case TT_COMPOSITE_RVALUE:
1025 case TT_COMPOSITE_ARRAY:
1026 case TT_ADD:
1027 if (output != inputA + inputB)
1028 res = QP_TEST_RESULT_FAIL;
1029 break;
1030 case TT_SUB:
1031 if (output != inputA - inputB)
1032 res = QP_TEST_RESULT_FAIL;
1033 break;
1034 case TT_DIV:
1035 {
1036 float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
1037 // division allows 2.5ulp, but we'll use 3.
1038 ulp *= 3;
1039 if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
1040 res = QP_TEST_RESULT_FAIL;
1041 }
1042 break;
1043 case TT_NEGATE:
1044 case TT_FUNC:
1045 if (output != -inputA)
1046 res = QP_TEST_RESULT_FAIL;
1047 break;
1048 case TT_MATRIXTIMESSCALAR:
1049 if (output != 6.0*inputA)
1050 res = QP_TEST_RESULT_FAIL;
1051 break;
1052 default:
1053 break;
1054 }
1055 }
1056 }
1057 else
1058 {
1059 deUint32 ik, kj, ij;
1060 for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1061 {
1062 for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1063 {
1064 for (deUint32 i = 0; i < M; ++i)
1065 {
1066 for (deUint32 j = 0; j < N; ++j)
1067 {
1068 float ref = 0;
1069 for (deUint32 k = 0; k < K; ++k)
1070 {
1071 if (m_data.colMajor)
1072 ik = mX * M + i + strides[0] * (mY * K + k);
1073 else
1074 ik = mX * K + k + strides[0] * (mY * M + i);
1075
1076 float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);
1077
1078 if (m_data.colMajor)
1079 kj = mX * K + k + strides[1] * (mY * N + j);
1080 else
1081 kj = mX * N + j + strides[1] * (mY * K + k);
1082
1083 float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);
1084
1085 ref += Aik*Bkj;
1086 }
1087
1088 if (m_data.colMajor)
1089 ij = mX * M + i + strides[2] * (mY * N + j);
1090 else
1091 ij = mX * N + j + strides[2] * (mY * M + i);
1092
1093 float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);
1094
1095 ref += Cij;
1096
1097 float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);
1098
1099 if (ref != Dij)
1100 {
1101 res = QP_TEST_RESULT_FAIL;
1102 }
1103 }
1104 }
1105 }
1106 }
1107 }
1108 } else {
1109 if (m_data.testType != TT_MATRIXMULADD &&
1110 m_data.testType != TT_MATRIXMULADD_ARRAY)
1111 {
1112 for (deUint32 i = 0; i < totalElements[3]; ++i)
1113 {
1114 deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
1115 deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
1116 deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
1117 int resultSize = componentTypeInfo[dataTypes[3]].bits;
1118 deUint32 mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
1119 switch (m_data.testType)
1120 {
1121 case TT_LENGTH:
1122 if (output < 1 || output > N*M)
1123 res = QP_TEST_RESULT_FAIL;
1124 // We expect the matrix to be spread evenly across invocations, it is
1125 // surprising (but not necessarily illegal) if not
1126 if (output != N*M/subgroupProperties.subgroupSize &&
1127 res == QP_TEST_RESULT_PASS)
1128 res = QP_TEST_RESULT_QUALITY_WARNING;
1129 break;
1130 case TT_CONSTANT:
1131 if (output != 1)
1132 res = QP_TEST_RESULT_FAIL;
1133 break;
1134 case TT_CONVERT:
1135 if (output != inputA)
1136 res = QP_TEST_RESULT_FAIL;
1137 break;
1138 case TT_COMPOSITE:
1139 case TT_COMPOSITE_RVALUE:
1140 case TT_COMPOSITE_ARRAY:
1141 case TT_ADD:
1142 if ((output & mask) != ((inputA + inputB) & mask)) {
1143 res = QP_TEST_RESULT_FAIL;
1144 }
1145 break;
1146 case TT_SUB:
1147 if ((output & mask) != ((inputA - inputB) & mask))
1148 res = QP_TEST_RESULT_FAIL;
1149 break;
1150 case TT_DIV:
1151 {
1152 if (isSIntType(dataTypes[3]))
1153 {
1154 if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
1155 res = QP_TEST_RESULT_FAIL;
1156 } else
1157 {
1158 if (inputB != 0 && output != inputA / inputB)
1159 res = QP_TEST_RESULT_FAIL;
1160 }
1161 }
1162 break;
1163 case TT_NEGATE:
1164 case TT_FUNC:
1165 if ((output & mask) != ((-(deInt32)inputA) & mask))
1166 res = QP_TEST_RESULT_FAIL;
1167 break;
1168 case TT_MATRIXTIMESSCALAR:
1169 if ((output & mask) != ((6*inputA) & mask)) {
1170 res = QP_TEST_RESULT_FAIL;
1171 }
1172 break;
1173 default:
1174 break;
1175 }
1176 }
1177 }
1178 else
1179 {
1180 deUint32 ik, kj, ij;
1181 for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1182 {
1183 for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1184 {
1185 for (deUint32 i = 0; i < M; ++i)
1186 {
1187 for (deUint32 j = 0; j < N; ++j)
1188 {
1189 deUint32 ref = 0;
1190 for (deUint32 k = 0; k < K; ++k)
1191 {
1192 if (m_data.colMajor)
1193 ik = mX * M + i + strides[0] * (mY * K + k);
1194 else
1195 ik = mX * K + k + strides[0] * (mY * M + i);
1196
1197 deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);
1198
1199 if (m_data.colMajor)
1200 kj = mX * K + k + strides[1] * (mY * N + j);
1201 else
1202 kj = mX * N + j + strides[1] * (mY * K + k);
1203
1204 deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);
1205
1206 ref += Aik*Bkj;
1207 }
1208
1209 if (m_data.colMajor)
1210 ij = mX * M + i + strides[2] * (mY * N + j);
1211 else
1212 ij = mX * N + j + strides[2] * (mY * M + i);
1213
1214 deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);
1215
1216 ref += Cij;
1217
1218 deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);
1219
1220 if (ref != Dij)
1221 {
1222 res = QP_TEST_RESULT_FAIL;
1223 }
1224 }
1225 }
1226 }
1227 }
1228 }
1229 }
1230 if (res != QP_TEST_RESULT_PASS)
1231 {
1232 log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
1233 finalres = res;
1234 }
1235 }
1236
1237 return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
1238 }
1239
1240 } // anonymous
1241
createCooperativeMatrixTests(tcu::TestContext& testCtx)1242 tcu::TestCaseGroup* createCooperativeMatrixTests (tcu::TestContext& testCtx)
1243 {
1244 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1245 testCtx, "cooperative_matrix", "GL_NV_cooperative_matrix tests"));
1246
1247 typedef struct
1248 {
1249 deUint32 value;
1250 const char* name;
1251 const char* description;
1252 } TestGroupCase;
1253
1254 typedef struct
1255 {
1256 deUint32 value[2];
1257 const char* name;
1258 const char* description;
1259 } TestGroupCase2;
1260
1261 TestGroupCase ttCases[] =
1262 {
1263 { TT_LENGTH, "length", "OpCooperativeMatrixLengthNV" },
1264 { TT_CONSTANT, "constant", "OpConstantComposite" },
1265 { TT_CONVERT, "convert", "OpFConvert/OpSConvert/OpUConvert" },
1266 { TT_COMPOSITE, "composite", "OpCompositeConstruct" },
1267 { TT_COMPOSITE_RVALUE, "composite_rvalue", "OpCompositeExtract" },
1268 { TT_ADD, "add", "OpFAdd/OpIAdd" },
1269 { TT_SUB, "sub", "OpFSub/OpISub" },
1270 { TT_DIV, "div", "OpFDiv/OpSDiv/OpUDiv" },
1271 { TT_NEGATE, "negate", "OpFNegate/OpSNegate" },
1272 { TT_MATRIXTIMESSCALAR, "matrixtimesscalar", "OpMatrixTimesScalar" },
1273 { TT_FUNC, "func", "OpFunctionParameter" },
1274 { TT_MATRIXMULADD, "matrixmuladd", "OpCooperativeMatrixMulAddNV" },
1275 { TT_COMPOSITE_ARRAY, "composite_array", "OpCompositeConstruct w/array" },
1276 { TT_MATRIXMULADD_ARRAY, "matrixmuladd_array", "OpCooperativeMatrixMulAddNV w/array" },
1277 };
1278
1279 TestGroupCase2 dtCases[] =
1280 {
1281 { { VK_COMPONENT_TYPE_FLOAT32_NV, VK_COMPONENT_TYPE_FLOAT32_NV }, "float32_float32", "A/B are fp32 C/D are fp32" },
1282 { { VK_COMPONENT_TYPE_FLOAT32_NV, VK_COMPONENT_TYPE_FLOAT16_NV }, "float32_float16", "A/B are fp32 C/D are fp16" },
1283 { { VK_COMPONENT_TYPE_FLOAT16_NV, VK_COMPONENT_TYPE_FLOAT32_NV }, "float16_float32", "A/B are fp16 C/D are fp32" },
1284 { { VK_COMPONENT_TYPE_FLOAT16_NV, VK_COMPONENT_TYPE_FLOAT16_NV }, "float16_float16", "A/B are fp16 C/D are fp16" },
1285 { { VK_COMPONENT_TYPE_UINT8_NV, VK_COMPONENT_TYPE_UINT8_NV }, "uint8_uint8", "A/B are u8 C/D are u8" },
1286 { { VK_COMPONENT_TYPE_UINT8_NV, VK_COMPONENT_TYPE_UINT32_NV }, "uint8_uint32", "A/B are u8 C/D are u32" },
1287 { { VK_COMPONENT_TYPE_SINT8_NV, VK_COMPONENT_TYPE_SINT8_NV }, "sint8_sint8", "A/B are s8 C/D are s8" },
1288 { { VK_COMPONENT_TYPE_SINT8_NV, VK_COMPONENT_TYPE_SINT32_NV }, "sint8_sint32", "A/B are s8 C/D are s32" },
1289 { { VK_COMPONENT_TYPE_UINT32_NV, VK_COMPONENT_TYPE_UINT32_NV }, "uint32_uint32", "A/B are u32 C/D are u32" },
1290 { { VK_COMPONENT_TYPE_UINT32_NV, VK_COMPONENT_TYPE_UINT8_NV }, "uint32_uint8", "A/B are u32 C/D are u8" },
1291 { { VK_COMPONENT_TYPE_SINT32_NV, VK_COMPONENT_TYPE_SINT32_NV }, "sint32_sint32", "A/B are s32 C/D are s32" },
1292 { { VK_COMPONENT_TYPE_SINT32_NV, VK_COMPONENT_TYPE_SINT8_NV }, "sint32_sint8", "A/B are s32 C/D are s8" },
1293 };
1294
1295 TestGroupCase colCases[] =
1296 {
1297 { 0, "rowmajor", "row major" },
1298 { 1, "colmajor", "col major" },
1299 };
1300
1301 TestGroupCase scCases[] =
1302 {
1303 { SC_BUFFER, "buffer", "SSBO" },
1304 { SC_WORKGROUP, "workgroup", "shared memory" },
1305 { SC_BUFFER_VARIABLE_POINTERS, "buffer_varptr", "SSBO w/variable pointers" },
1306 { SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr", "shared memory w/variable pointers" },
1307 { SC_PHYSICAL_STORAGE_BUFFER, "physical_buffer", "physical_storage_buffer" },
1308 };
1309
1310 for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1311 {
1312 de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1313 for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
1314 {
1315 de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name, dtCases[dtNdx].description));
1316 for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
1317 {
1318 de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, scCases[scNdx].description));
1319 for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
1320 {
1321 TestType testType = (TestType)ttCases[ttNdx].value;
1322 VkComponentTypeNV inputType = (VkComponentTypeNV)dtCases[dtNdx].value[0];
1323 VkComponentTypeNV outputType = (VkComponentTypeNV)dtCases[dtNdx].value[1];
1324
1325 bool isMatrixMul = testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY;
1326
1327 if (!isMatrixMul && testType != TT_CONVERT && inputType != outputType)
1328 continue;
1329
1330 if (testType == TT_CONVERT && inputType == outputType)
1331 continue;
1332
1333 if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
1334 continue;
1335
1336 CaseDef c =
1337 {
1338 testType, // TestType testtype;
1339 2u, // deUint32 subgroupsPerWorkgroupX;
1340 2u, // deUint32 subgroupsPerWorkgroupY;
1341 4u, // deUint32 workgroupsX;
1342 4u, // deUint32 workgroupsY;
1343 (VkComponentTypeNV)inputType, // VkComponentTypeNV inputType;
1344 (VkComponentTypeNV)outputType, // VkComponentTypeNV outputType;
1345 !!colCases[colNdx].value, // bool colMajor;
1346 (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
1347 };
1348
1349 scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, colCases[colNdx].description, c));
1350 }
1351 dtGroup->addChild(scGroup.release());
1352 }
1353 ttGroup->addChild(dtGroup.release());
1354 }
1355 group->addChild(ttGroup.release());
1356 }
1357 return group.release();
1358 }
1359
1360 } // compute
1361 } // vkt
1362