1/*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *	  http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 *//*!
20 * \file
21 * \brief Ray Tracing Complex Control Flow tests
22 *//*--------------------------------------------------------------------*/
23
24#include "vktRayTracingComplexControlFlowTests.hpp"
25
26#include "vkDefs.hpp"
27
28#include "vktTestCase.hpp"
29#include "vkCmdUtil.hpp"
30#include "vkObjUtil.hpp"
31#include "vkBuilderUtil.hpp"
32#include "vkBarrierUtil.hpp"
33#include "vkBufferWithMemory.hpp"
34#include "vkImageWithMemory.hpp"
35#include "vkTypeUtil.hpp"
36
37#include "vkRayTracingUtil.hpp"
38
39#include "tcuTestLog.hpp"
40
41#include "deRandom.hpp"
42
43namespace vkt
44{
45namespace RayTracing
46{
47namespace
48{
49using namespace vk;
50using namespace std;
51
52static const VkFlags	ALL_RAY_TRACING_STAGES	= VK_SHADER_STAGE_RAYGEN_BIT_KHR
53												| VK_SHADER_STAGE_ANY_HIT_BIT_KHR
54												| VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
55												| VK_SHADER_STAGE_MISS_BIT_KHR
56												| VK_SHADER_STAGE_INTERSECTION_BIT_KHR
57												| VK_SHADER_STAGE_CALLABLE_BIT_KHR;
58
59#if defined(DE_DEBUG)
60static const deUint32	PUSH_CONSTANTS_COUNT	= 6;
61#endif
62static const deUint32	DEFAULT_CLEAR_VALUE		= 999999;
63
// Control-flow construct exercised by a test case. The value selects the push constants
// (getPushConstants) and the reference computation (getExpectedValues).
enum TestType
{
	TEST_TYPE_IF						= 0,
	TEST_TYPE_LOOP						= 1,
	TEST_TYPE_SWITCH					= 2,
	TEST_TYPE_LOOP_DOUBLE_CALL			= 3,
	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE	= 4,
	TEST_TYPE_NESTED_LOOP				= 5,
	TEST_TYPE_NESTED_LOOP_BEFORE		= 6,
	TEST_TYPE_NESTED_LOOP_AFTER			= 7,
	TEST_TYPE_FUNCTION_CALL				= 8,
	TEST_TYPE_NESTED_FUNCTION_CALL		= 9,
};
77
// Ray tracing operation placed inside the control flow under test.
// TEST_OP_TRACE_RAY raises the pipeline recursion depth when issued outside the raygen
// stage; TEST_OP_REPORT_INTERSECTION switches the reference values to their "fixed" form.
enum TestOp
{
	TEST_OP_EXECUTE_CALLABLE		= 0,
	TEST_OP_TRACE_RAY				= 1,
	TEST_OP_REPORT_INTERSECTION		= 2,
};
84
// Symbolic shader group indices: raygen first, then miss, then hit.
// GROUP_COUNT is the number of groups listed here.
enum ShaderGroups
{
	FIRST_GROUP		= 0,
	RAYGEN_GROUP	= FIRST_GROUP,
	MISS_GROUP		= 1,
	HIT_GROUP		= 2,
	GROUP_COUNT		= 3
};
93
94struct CaseDef
95{
96	TestType				testType;
97	TestOp					testOp;
98	VkShaderStageFlagBits	stage;
99	deUint32				width;
100	deUint32				height;
101};
102
103struct PushConstants
104{
105	deUint32	a;
106	deUint32	b;
107	deUint32	c;
108	deUint32	d;
109	deUint32	hitOfs;
110	deUint32	miss;
111};
112
113deUint32 getShaderGroupSize (const InstanceInterface&	vki,
114							 const VkPhysicalDevice		physicalDevice)
115{
116	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
117
118	rayTracingPropertiesKHR	= makeRayTracingProperties(vki, physicalDevice);
119	return rayTracingPropertiesKHR->getShaderGroupHandleSize();
120}
121
122deUint32 getShaderGroupBaseAlignment (const InstanceInterface&	vki,
123									  const VkPhysicalDevice	physicalDevice)
124{
125	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
126
127	rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
128	return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
129}
130
131VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, deUint32 depth, VkFormat format)
132{
133	const VkImageUsageFlags	usage			= VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
134	const VkImageCreateInfo	imageCreateInfo	=
135	{
136		VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,	// VkStructureType			sType;
137		DE_NULL,								// const void*				pNext;
138		(VkImageCreateFlags)0u,					// VkImageCreateFlags		flags;
139		VK_IMAGE_TYPE_3D,						// VkImageType				imageType;
140		format,									// VkFormat					format;
141		makeExtent3D(width, height, depth),		// VkExtent3D				extent;
142		1u,										// deUint32					mipLevels;
143		1u,										// deUint32					arrayLayers;
144		VK_SAMPLE_COUNT_1_BIT,					// VkSampleCountFlagBits	samples;
145		VK_IMAGE_TILING_OPTIMAL,				// VkImageTiling			tiling;
146		usage,									// VkImageUsageFlags		usage;
147		VK_SHARING_MODE_EXCLUSIVE,				// VkSharingMode			sharingMode;
148		0u,										// deUint32					queueFamilyIndexCount;
149		DE_NULL,								// const deUint32*			pQueueFamilyIndices;
150		VK_IMAGE_LAYOUT_UNDEFINED				// VkImageLayout			initialLayout;
151	};
152
153	return imageCreateInfo;
154}
155
156Move<VkPipelineLayout> makePipelineLayout (const DeviceInterface&		vk,
157										   const VkDevice				device,
158										   const VkDescriptorSetLayout	descriptorSetLayout,
159										   const deUint32				pushConstantsSize)
160{
161	const VkDescriptorSetLayout*		descriptorSetLayoutPtr	= (descriptorSetLayout == DE_NULL) ? DE_NULL : &descriptorSetLayout;
162	const deUint32						setLayoutCount			= (descriptorSetLayout == DE_NULL) ? 0u : 1u;
163	const VkPushConstantRange			pushConstantRange		=
164	{
165		ALL_RAY_TRACING_STAGES,		//  VkShaderStageFlags	stageFlags;
166		0u,							//  deUint32			offset;
167		pushConstantsSize,			//  deUint32			size;
168	};
169	const VkPushConstantRange*			pPushConstantRanges		= (pushConstantsSize == 0) ? DE_NULL : &pushConstantRange;
170	const deUint32						pushConstantRangeCount	= (pushConstantsSize == 0) ? 0 : 1u;
171	const VkPipelineLayoutCreateInfo	pipelineLayoutParams	=
172	{
173		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,	// VkStructureType					sType;
174		DE_NULL,										// const void*						pNext;
175		0u,												// VkPipelineLayoutCreateFlags		flags;
176		setLayoutCount,									// deUint32							setLayoutCount;
177		descriptorSetLayoutPtr,							// const VkDescriptorSetLayout*		pSetLayouts;
178		pushConstantRangeCount,							// deUint32							pushConstantRangeCount;
179		pPushConstantRanges,							// const VkPushConstantRange*		pPushConstantRanges;
180	};
181
182	return createPipelineLayout(vk, device, &pipelineLayoutParams);
183}
184
185VkBuffer getVkBuffer (const de::MovePtr<BufferWithMemory>& buffer)
186{
187	VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
188
189	return result;
190}
191
192VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion (const DeviceInterface& vkd, const VkDevice device, VkBuffer buffer, deUint32 stride, deUint32 count)
193{
194	if (buffer == DE_NULL)
195	{
196		return makeStridedDeviceAddressRegionKHR(0, 0, 0);
197	}
198	else
199	{
200		return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride, stride * count);
201	}
202}
203
// Returns a copy of str in which every non-overlapping occurrence of from is replaced by to.
//
// Occurrences are searched left to right; text inserted by a replacement is never rescanned,
// so a replacement that itself contains from (e.g. "a" -> "aa") is applied once per match.
// An empty from matches nothing and str is returned unchanged.
static inline std::string replace(const std::string& str, const std::string& from, const std::string& to)
{
	std::string result(str);

	// std::string::find reports a match for the empty string at every position, so without
	// this guard the loop below would keep inserting `to` and never terminate.
	if (from.empty())
		return result;

	for (size_t pos = result.find(from); pos != std::string::npos; pos = result.find(from, pos + to.length()))
		result.replace(pos, from.length(), to);

	return result;
}
218
219
220class RayTracingComplexControlFlowInstance : public TestInstance
221{
222public:
223																RayTracingComplexControlFlowInstance	(Context& context, const CaseDef& data);
224																~RayTracingComplexControlFlowInstance	(void);
225	tcu::TestStatus												iterate									(void);
226
227protected:
228	void														calcShaderGroup							(deUint32&					shaderGroupCounter,
229																										 const VkShaderStageFlags	shaders1,
230																										 const VkShaderStageFlags	shaders2,
231																										 const VkShaderStageFlags	shaderStageFlags,
232																										 deUint32&					shaderGroup,
233																										 deUint32&					shaderGroupCount) const;
234	PushConstants												getPushConstants						(void) const;
235	std::vector<deUint32>										getExpectedValues						(void) const;
236	de::MovePtr<BufferWithMemory>								runTest									(void);
237	Move<VkPipeline>											makePipeline							(de::MovePtr<RayTracingPipeline>&							rayTracingPipeline,
238																										 VkPipelineLayout											pipelineLayout);
239	de::MovePtr<BufferWithMemory>								createShaderBindingTable				 (const InstanceInterface&									vki,
240																										 const DeviceInterface&										vkd,
241																										 const VkDevice												device,
242																										 const VkPhysicalDevice										physicalDevice,
243																										 const VkPipeline											pipeline,
244																										 Allocator&													allocator,
245																										 de::MovePtr<RayTracingPipeline>&							rayTracingPipeline,
246																										 const deUint32												group,
247																										 const deUint32												groupCount = 1u);
248	de::MovePtr<TopLevelAccelerationStructure>					initTopAccelerationStructure			(VkCommandBuffer											cmdBuffer,
249																										 vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures);
250	vector<de::SharedPtr<BottomLevelAccelerationStructure>	>	initBottomAccelerationStructures		(VkCommandBuffer											cmdBuffer);
251	de::MovePtr<BottomLevelAccelerationStructure>				initBottomAccelerationStructure			(VkCommandBuffer											cmdBuffer,
252																										 tcu::UVec2&												startPos);
253
254private:
255	CaseDef														m_data;
256	VkShaderStageFlags											m_shaders;
257	VkShaderStageFlags											m_shaders2;
258	deUint32													m_raygenShaderGroup;
259	deUint32													m_missShaderGroup;
260	deUint32													m_hitShaderGroup;
261	deUint32													m_callableShaderGroup;
262	deUint32													m_raygenShaderGroupCount;
263	deUint32													m_missShaderGroupCount;
264	deUint32													m_hitShaderGroupCount;
265	deUint32													m_callableShaderGroupCount;
266	deUint32													m_shaderGroupCount;
267	deUint32													m_depth;
268	PushConstants												m_pushConstants;
269};
270
271RayTracingComplexControlFlowInstance::RayTracingComplexControlFlowInstance (Context& context, const CaseDef& data)
272	: vkt::TestInstance				(context)
273	, m_data						(data)
274	, m_shaders						(0)
275	, m_shaders2					(0)
276	, m_raygenShaderGroup			(~0u)
277	, m_missShaderGroup				(~0u)
278	, m_hitShaderGroup				(~0u)
279	, m_callableShaderGroup			(~0u)
280	, m_raygenShaderGroupCount		(0)
281	, m_missShaderGroupCount		(0)
282	, m_hitShaderGroupCount			(0)
283	, m_callableShaderGroupCount	(0)
284	, m_shaderGroupCount			(0)
285	, m_depth						(16)
286	, m_pushConstants				(getPushConstants())
287{
288	const VkShaderStageFlags	hitStages	= VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
289	BinaryCollection&			collection	= m_context.getBinaryCollection();
290	deUint32					shaderCount	= 0;
291
292	if (collection.contains("rgen")) m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
293	if (collection.contains("ahit")) m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
294	if (collection.contains("chit")) m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
295	if (collection.contains("miss")) m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
296	if (collection.contains("sect")) m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
297	if (collection.contains("call")) m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
298
299	if (collection.contains("ahit2")) m_shaders2 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
300	if (collection.contains("chit2")) m_shaders2 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
301	if (collection.contains("miss2")) m_shaders2 |= VK_SHADER_STAGE_MISS_BIT_KHR;
302	if (collection.contains("sect2")) m_shaders2 |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
303
304	if (collection.contains("cal0")) m_shaders2 |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
305
306	for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
307		shaderCount++;
308
309	if (shaderCount != (deUint32)dePop32(m_shaders) + (deUint32)dePop32(m_shaders2))
310		TCU_THROW(InternalError, "Unused shaders detected in the collection");
311
312	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_RAYGEN_BIT_KHR,   m_raygenShaderGroup,   m_raygenShaderGroupCount);
313	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_MISS_BIT_KHR,     m_missShaderGroup,     m_missShaderGroupCount);
314	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, hitStages,                        m_hitShaderGroup,      m_hitShaderGroupCount);
315	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_CALLABLE_BIT_KHR, m_callableShaderGroup, m_callableShaderGroupCount);
316}
317
318RayTracingComplexControlFlowInstance::~RayTracingComplexControlFlowInstance (void)
319{
320}
321
322void RayTracingComplexControlFlowInstance::calcShaderGroup (deUint32&					shaderGroupCounter,
323															const VkShaderStageFlags	shaders1,
324															const VkShaderStageFlags	shaders2,
325															const VkShaderStageFlags	shaderStageFlags,
326															deUint32&					shaderGroup,
327															deUint32&					shaderGroupCount) const
328{
329	const deUint32	shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
330	const deUint32	shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
331
332	shaderGroupCount = shader1Count + shader2Count;
333
334	if (shaderGroupCount != 0)
335	{
336		shaderGroup			= shaderGroupCounter;
337		shaderGroupCounter += shaderGroupCount;
338	}
339}
340
341Move<VkPipeline> RayTracingComplexControlFlowInstance::makePipeline (de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
342																	  VkPipelineLayout					pipelineLayout)
343{
344	const DeviceInterface&	vkd			= m_context.getDeviceInterface();
345	const VkDevice			device		= m_context.getDevice();
346	vk::BinaryCollection&	collection	= m_context.getBinaryCollection();
347
348	if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR		, createShaderModule(vkd, device, collection.get("rgen"), 0), m_raygenShaderGroup);
349	if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR		, createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
350	if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	, createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
351	if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR			, createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
352	if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR	, createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
353	if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR		, createShaderModule(vkd, device, collection.get("call"), 0), m_callableShaderGroup + 1);
354
355	if (0 != (m_shaders2 & VK_SHADER_STAGE_CALLABLE_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR		, createShaderModule(vkd, device, collection.get("cal0"), 0), m_callableShaderGroup);
356	if (0 != (m_shaders2 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR		, createShaderModule(vkd, device, collection.get("ahit2"), 0), m_hitShaderGroup + 1);
357	if (0 != (m_shaders2 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	, createShaderModule(vkd, device, collection.get("chit2"), 0), m_hitShaderGroup + 1);
358	if (0 != (m_shaders2 & VK_SHADER_STAGE_MISS_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR			, createShaderModule(vkd, device, collection.get("miss2"), 0), m_missShaderGroup + 1);
359	if (0 != (m_shaders2 & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR	, createShaderModule(vkd, device, collection.get("sect2"), 0), m_hitShaderGroup + 1);
360
361	if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
362		rayTracingPipeline->setMaxRecursionDepth(2);
363
364	Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout);
365
366	return pipeline;
367}
368
369de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::createShaderBindingTable (const InstanceInterface&			vki,
370																							  const DeviceInterface&			vkd,
371																							  const VkDevice					device,
372																							  const VkPhysicalDevice			physicalDevice,
373																							  const VkPipeline					pipeline,
374																							  Allocator&						allocator,
375																							  de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
376																							  const deUint32					group,
377																							  const deUint32					groupCount)
378{
379	de::MovePtr<BufferWithMemory>	shaderBindingTable;
380
381	if (group < m_shaderGroupCount)
382	{
383		const deUint32	shaderGroupHandleSize		= getShaderGroupSize(vki, physicalDevice);
384		const deUint32	shaderGroupBaseAlignment	= getShaderGroupBaseAlignment(vki, physicalDevice);
385
386		shaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
387	}
388
389	return shaderBindingTable;
390}
391
392
393de::MovePtr<TopLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initTopAccelerationStructure (VkCommandBuffer												cmdBuffer,
394																											   vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures)
395{
396	const DeviceInterface&						vkd			= m_context.getDeviceInterface();
397	const VkDevice								device		= m_context.getDevice();
398	Allocator&									allocator	= m_context.getDefaultAllocator();
399	de::MovePtr<TopLevelAccelerationStructure>	result		= makeTopLevelAccelerationStructure();
400
401	result->setInstanceCount(bottomLevelAccelerationStructures.size());
402
403	for (size_t structNdx = 0; structNdx < bottomLevelAccelerationStructures.size(); ++structNdx)
404		result->addInstance(bottomLevelAccelerationStructures[structNdx]);
405
406	result->createAndBuild(vkd, device, cmdBuffer, allocator);
407
408	return result;
409}
410
411de::MovePtr<BottomLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initBottomAccelerationStructure (VkCommandBuffer	cmdBuffer,
412																													 tcu::UVec2&		startPos)
413{
414	const DeviceInterface&							vkd				= m_context.getDeviceInterface();
415	const VkDevice									device			= m_context.getDevice();
416	Allocator&										allocator		= m_context.getDefaultAllocator();
417	de::MovePtr<BottomLevelAccelerationStructure>	result			= makeBottomLevelAccelerationStructure();
418	const float										z				= (m_data.stage == VK_SHADER_STAGE_MISS_BIT_KHR) ? +1.0f : -1.0f;
419	std::vector<tcu::Vec3>							geometryData;
420
421	DE_UNREF(startPos);
422
423	result->setGeometryCount(1);
424	geometryData.push_back(tcu::Vec3(0.0f, 0.0f, z));
425	geometryData.push_back(tcu::Vec3(1.0f, 1.0f, z));
426	result->addGeometry(geometryData, false);
427	result->createAndBuild(vkd, device, cmdBuffer, allocator);
428
429	return result;
430}
431
432vector<de::SharedPtr<BottomLevelAccelerationStructure> > RayTracingComplexControlFlowInstance::initBottomAccelerationStructures (VkCommandBuffer	cmdBuffer)
433{
434	tcu::UVec2													startPos;
435	vector<de::SharedPtr<BottomLevelAccelerationStructure> >	result;
436	de::MovePtr<BottomLevelAccelerationStructure>				bottomLevelAccelerationStructure	= initBottomAccelerationStructure(cmdBuffer, startPos);
437
438	result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
439
440	return result;
441}
442
443PushConstants RayTracingComplexControlFlowInstance::getPushConstants (void) const
444{
445	const			deUint32	hitOfs	= 1;
446	const			deUint32	miss	= 1;
447	PushConstants	result;
448
449	switch (m_data.testType)
450	{
451		case TEST_TYPE_IF:
452		{
453			result = { 32 | 8 | 1, 10000, 0x0F, 0xF0, hitOfs, miss };
454
455			break;
456		}
457		case TEST_TYPE_LOOP:
458		{
459			result = { 8, 10000, 0x0F, 100000, hitOfs, miss };
460
461			break;
462		}
463		case TEST_TYPE_SWITCH:
464		{
465			result = { 3, 10000, 0x07, 100000, hitOfs, miss };
466
467			break;
468		}
469		case TEST_TYPE_LOOP_DOUBLE_CALL:
470		{
471			result = { 7, 10000, 0x0F, 0xF0, hitOfs, miss };
472
473			break;
474		}
475		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
476		{
477			result = { 16, 5, 0x0F, 0xF0, hitOfs, miss };
478
479			break;
480		}
481		case TEST_TYPE_NESTED_LOOP:
482		{
483			result = { 8, 5, 0x0F, 0x09, hitOfs, miss };
484
485			break;
486		}
487		case TEST_TYPE_NESTED_LOOP_BEFORE:
488		{
489			result = { 9, 16, 0x0F, 10, hitOfs, miss };
490
491			break;
492		}
493		case TEST_TYPE_NESTED_LOOP_AFTER:
494		{
495			result = { 9, 16, 0x0F, 10, hitOfs, miss };
496
497			break;
498		}
499		case TEST_TYPE_FUNCTION_CALL:
500		{
501			result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
502
503			break;
504		}
505		case TEST_TYPE_NESTED_FUNCTION_CALL:
506		{
507			result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
508
509			break;
510		}
511
512		default:
513			TCU_THROW(InternalError, "Unknown testType");
514	}
515
516	return result;
517}
518
519de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::runTest (void)
520{
521	const InstanceInterface&				vki									= m_context.getInstanceInterface();
522	const DeviceInterface&					vkd									= m_context.getDeviceInterface();
523	const VkDevice							device								= m_context.getDevice();
524	const VkPhysicalDevice					physicalDevice						= m_context.getPhysicalDevice();
525	const deUint32							queueFamilyIndex					= m_context.getUniversalQueueFamilyIndex();
526	const VkQueue							queue								= m_context.getUniversalQueue();
527	Allocator&								allocator							= m_context.getDefaultAllocator();
528	const VkFormat							format								= VK_FORMAT_R32_UINT;
529	const deUint32							pushConstants[]						= { m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
530	const deUint32							pushConstantsSize					= sizeof(pushConstants);
531	const deUint32							pixelCount							= m_data.width * m_data.height * m_depth;
532	const deUint32							shaderGroupHandleSize				= getShaderGroupSize(vki, physicalDevice);
533
534	const Move<VkDescriptorSetLayout>		descriptorSetLayout					= DescriptorSetLayoutBuilder()
535																						.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
536																						.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
537																						.build(vkd, device);
538	const Move<VkDescriptorPool>			descriptorPool						= DescriptorPoolBuilder()
539																						.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
540																						.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
541																						.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
542	const Move<VkDescriptorSet>				descriptorSet						= makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
543	const Move<VkPipelineLayout>			pipelineLayout						= makePipelineLayout(vkd, device, descriptorSetLayout.get(), pushConstantsSize);
544	const Move<VkCommandPool>				cmdPool								= createCommandPool(vkd, device, 0, queueFamilyIndex);
545	const Move<VkCommandBuffer>				cmdBuffer							= allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
546
547	de::MovePtr<RayTracingPipeline>			rayTracingPipeline					= de::newMovePtr<RayTracingPipeline>();
548	const Move<VkPipeline>					pipeline							= makePipeline(rayTracingPipeline, *pipelineLayout);
549	const de::MovePtr<BufferWithMemory>		raygenShaderBindingTable			= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_raygenShaderGroup, m_raygenShaderGroupCount);
550	const de::MovePtr<BufferWithMemory>		missShaderBindingTable				= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_missShaderGroup, m_missShaderGroupCount);
551	const de::MovePtr<BufferWithMemory>		hitShaderBindingTable				= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_hitShaderGroup, m_hitShaderGroupCount);
552	const de::MovePtr<BufferWithMemory>		callableShaderBindingTable			= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_callableShaderGroup, m_callableShaderGroupCount);
553
554	const VkStridedDeviceAddressRegionKHR	raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(raygenShaderBindingTable),   shaderGroupHandleSize, m_raygenShaderGroupCount);
555	const VkStridedDeviceAddressRegionKHR	missShaderBindingTableRegion		= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(missShaderBindingTable),     shaderGroupHandleSize, m_missShaderGroupCount);
556	const VkStridedDeviceAddressRegionKHR	hitShaderBindingTableRegion			= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(hitShaderBindingTable),      shaderGroupHandleSize, m_hitShaderGroupCount);
557	const VkStridedDeviceAddressRegionKHR	callableShaderBindingTableRegion	= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(callableShaderBindingTable), shaderGroupHandleSize, m_callableShaderGroupCount);
558
559	const VkImageCreateInfo					imageCreateInfo						= makeImageCreateInfo(m_data.width, m_data.height, m_depth, format);
560	const VkImageSubresourceRange			imageSubresourceRange				= makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0, 1u);
561	const de::MovePtr<ImageWithMemory>		image								= de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
562	const Move<VkImageView>					imageView							= makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, format, imageSubresourceRange);
563
564	const VkBufferCreateInfo				bufferCreateInfo					= makeBufferCreateInfo(pixelCount*sizeof(deUint32), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
565	const VkImageSubresourceLayers			bufferImageSubresourceLayers		= makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
566	const VkBufferImageCopy					bufferImageRegion					= makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, m_depth), bufferImageSubresourceLayers);
567	de::MovePtr<BufferWithMemory>			buffer								= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, bufferCreateInfo, MemoryRequirement::HostVisible));
568
569	const VkDescriptorImageInfo				descriptorImageInfo					= makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
570
571	const VkImageMemoryBarrier				preImageBarrier						= makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
572																					VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
573																					**image, imageSubresourceRange);
574	const VkImageMemoryBarrier				postImageBarrier					= makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT,
575																					VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
576																					**image, imageSubresourceRange);
577	const VkMemoryBarrier					preTraceMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
578	const VkMemoryBarrier					postTraceMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
579	const VkMemoryBarrier					postCopyMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
580	const VkClearValue						clearValue							= makeClearValueColorU32(DEFAULT_CLEAR_VALUE, 0u, 0u, 255u);
581
582	vector<de::SharedPtr<BottomLevelAccelerationStructure> >	bottomLevelAccelerationStructures;
583	de::MovePtr<TopLevelAccelerationStructure>					topLevelAccelerationStructure;
584
585	DE_ASSERT(DE_LENGTH_OF_ARRAY(pushConstants) == PUSH_CONSTANTS_COUNT);
586
587	beginCommandBuffer(vkd, *cmdBuffer, 0u);
588	{
589		vkd.cmdPushConstants(*cmdBuffer, *pipelineLayout, ALL_RAY_TRACING_STAGES, 0, pushConstantsSize, &m_pushConstants);
590
591		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
592		vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
593		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &postImageBarrier);
594
595		bottomLevelAccelerationStructures = initBottomAccelerationStructures(*cmdBuffer);
596		topLevelAccelerationStructure = initTopAccelerationStructure(*cmdBuffer, bottomLevelAccelerationStructures);
597
598		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &preTraceMemoryBarrier);
599
600		const TopLevelAccelerationStructure*			topLevelAccelerationStructurePtr		= topLevelAccelerationStructure.get();
601		VkWriteDescriptorSetAccelerationStructureKHR	accelerationStructureWriteDescriptorSet	=
602		{
603			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,	//  VkStructureType						sType;
604			DE_NULL,															//  const void*							pNext;
605			1u,																	//  deUint32							accelerationStructureCount;
606			topLevelAccelerationStructurePtr->getPtr(),							//  const VkAccelerationStructureKHR*	pAccelerationStructures;
607		};
608
609		DescriptorSetUpdateBuilder()
610			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
611			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
612			.update(vkd, device);
613
614		vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
615
616		vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
617
618		cmdTraceRays(vkd,
619			*cmdBuffer,
620			&raygenShaderBindingTableRegion,
621			&missShaderBindingTableRegion,
622			&hitShaderBindingTableRegion,
623			&callableShaderBindingTableRegion,
624			m_data.width, m_data.height, 1);
625
626		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, ALL_RAY_TRACING_STAGES, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
627
628		vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **buffer, 1u, &bufferImageRegion);
629
630		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
631	}
632	endCommandBuffer(vkd, *cmdBuffer);
633
634	submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
635
636	invalidateMappedMemoryRange(vkd, device, buffer->getAllocation().getMemory(), buffer->getAllocation().getOffset(), pixelCount * sizeof(deUint32));
637
638	return buffer;
639}
640
// Computes on the host the values the result buffer is expected to contain.
//
// The returned vector has width * height * m_depth entries, pre-filled with DEFAULT_CLEAR_VALUE and
// addressed as layer * (width * height) + id, with id in [0, width * height):
//   layer 0      - per-id result accumulated by the control-flow pattern selected by m_data.testType
//   layers 1..6  - push constants a, b, c, d, hitOfs and miss, replicated for every id
//   layer 7      - the id itself
//   layers 8+    - values stored at layer ((slot % 8) + 8); this matches the imageStore issued by the
//                  callee snippet (calleeMainPart) that initPrograms puts into the called shaders
//
// Throws InternalError when m_data.testType is not one of the handled values.
// NOTE(review): layers up to 15 are indexed without a bounds check, so m_depth is presumably >= 16 - confirm.
std::vector<deUint32> RayTracingComplexControlFlowInstance::getExpectedValues (void) const
{
	const deUint32				plainSize		= m_data.width * m_data.height;
	// Offset of layer 8, the first layer written by the called shader.
	const deUint32				plain8Ofs		= 8 * plainSize;
	const struct PushConstants&	p				= m_pushConstants;
	// Index 0 is a placeholder so that the array index equals the layer index; the fill loop at the end starts at 1.
	const deUint32				pushConstants[]	= { 0, m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
	const deUint32				resultSize		= plainSize * m_depth;
	// True for TEST_OP_REPORT_INTERSECTION. In that mode the "++" that follows each recorded value is skipped;
	// the any-hit shader generated by initPrograms for this op has no "inValue.y++" while the callable and miss callees do.
	const bool					fixed			= m_data.testOp == TEST_OP_REPORT_INTERSECTION;
	std::vector<deUint32>		result			(resultSize, DEFAULT_CLEAR_VALUE);
	// Host-side counterparts of the .y components of the shader variables v0..v3.
	deUint32					v0;
	deUint32					v1;
	deUint32					v2;
	deUint32					v3;

	switch (m_data.testType)
	{
		case TEST_TYPE_IF:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				v2 = v3 = p.b;

				if ((p.a & id) != 0)
				{
					v0 = p.c & id;
					v1 = (p.d & id) + 1;

					// Taken branch: v0 is recorded in layer 8 and then incremented unless fixed.
					result[plain8Ofs + id] = v0;
					if (!fixed) v0++;
				}
				else
				{
					v0 = p.d & id;
					v1 = (p.c & id) + 1;

					// Else branch: v1 is recorded and incremented, except in fixed mode where v0 is recorded unchanged.
					if (!fixed)
					{
						result[plain8Ofs + id] = v1;
						v1++;
					}
					else
						result[plain8Ofs + id] = v0;
				}

				result[id] = v0 + v1 + v2 + v3;
			}

			break;
		}
		case TEST_TYPE_LOOP:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				v1 = v3 = p.b;

				for (deUint32 n = 0; n < p.a; n++)
				{
					v0 = (p.c & id) + n;

					// Iteration n records into layer (n % 8) + 8, so later iterations overwrite earlier ones that share a slot.
					result[((n % 8) + 8) * plainSize + id] = v0;
					if (!fixed) v0++;

					result[id] += v0 + v1 + v3;
				}
			}

			break;
		}
		case TEST_TYPE_SWITCH:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				// Exactly one of v0..v3 receives (p.c & id); the others receive p.b. Selectors above 3 zero everything.
				switch (p.a & id)
				{
					case 0: { v1 = v2 = v3 = p.b; v0 = p.c & id; break; }
					case 1: { v0 = v2 = v3 = p.b; v1 = p.c & id; break; }
					case 2: { v0 = v1 = v3 = p.b; v2 = p.c & id; break; }
					case 3: { v0 = v1 = v2 = p.b; v3 = p.c & id; break; }
					default: { v0 = v1 = v2 = v3 = 0; break; }
				}

				if (!fixed)
					result[plain8Ofs + id] = p.c & id;
				else
					result[plain8Ofs + id] = v0;

				result[id] = v0 + v1 + v2 + v3;

				// Accounts for the single increment applied to whichever variable was selected.
				if (!fixed) result[id]++;
			}

			break;
		}
		case TEST_TYPE_LOOP_DOUBLE_CALL:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				v3 = p.b;

				for (deUint32 x = 0; x < p.a; x++)
				{
					v0 = (p.c & id) + x;
					v1 = (p.d & id) + x + 1;

					// Two records per iteration: slot 2x for v0 and slot 2x + 1 for v1 (the latter only when not fixed).
					result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
					if (!fixed) v0++;

					if (!fixed)
					{
						result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
						v1++;
					}

					result[id] += v0 + v1 + v3;
				}
			}

			break;
		}
		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				v3 = p.a + p.b;

				for (deUint32 x = 0; x < p.a; x++)
				{
					// Same as the double-call loop, but only iterations with (x & p.b) != 0 contribute.
					if ((x & p.b) != 0)
					{
						v0 = (p.c & id) + x;
						v1 = (p.d & id) + x + 1;

						result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
						if (!fixed) v0++;

						if (!fixed)
						{
							result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
							v1++;
						}

						result[id] += v0 + v1 + v3;
					}
				}
			}

			break;
		}
		case TEST_TYPE_NESTED_LOOP:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				v1 = v3 = p.b;

				for (deUint32 y = 0; y < p.a; y++)
				for (deUint32 x = 0; x < p.a; x++)
				{
					// Linearised iteration number of the p.a x p.a nested loop.
					const deUint32 n = x + y * p.a;

					if ((n & p.d) != 0)
					{
						v0 = (p.c & id) + n;

						result[((n % 8) + 8) * plainSize + id] = v0;
						if (!fixed) v0++;

						result[id] += v0 + v1 + v3;
					}
				}
			}

			break;
		}
		case TEST_TYPE_NESTED_LOOP_BEFORE:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				// Plain nested loop without any recorded value, placed before the recording loop.
				for (deUint32 y = 0; y < p.d; y++)
				for (deUint32 x = 0; x < p.d; x++)
				{
					if (((x + y * p.a) & p.b) != 0)
						result[id] += (x + y);
				}

				v1 = v3 = p.a;

				for (deUint32 x = 0; x < p.b; x++)
				{
					if ((x & p.a) != 0)
					{
						v0 = p.c & id;

						result[((x % 8) + 8) * plainSize + id] = v0;
						if (!fixed) v0++;

						result[id] += v0 + v1 + v3;
					}
				}
			}

			break;
		}
		case TEST_TYPE_NESTED_LOOP_AFTER:
		{
			for (deUint32 id = 0; id < plainSize; ++id)
			{
				result[id] = 0;

				v1 = v3 = p.a;

				for (deUint32 x = 0; x < p.b; x++)
				{
					if ((x & p.a) != 0)
					{
						v0 = p.c & id;

						result[((x % 8) + 8) * plainSize + id] = v0;
						if (!fixed) v0++;

						result[id] += v0 + v1 + v3;
					}
				}

				// Plain nested loop without any recorded value, placed after the recording loop.
				for (deUint32 y = 0; y < p.d; y++)
				for (deUint32 x = 0; x < p.d; x++)
				{
					if (((x + y * p.a) & p.b) != 0)
						result[id] += (x + y);
				}
			}

			break;
		}
		case TEST_TYPE_FUNCTION_CALL:
		{
			// Same element count as the array a[] in the generated f1().
			deUint32 a[42];

			for (deUint32 id = 0; id < plainSize; ++id)
			{
				deUint32 r = 0;
				deUint32 i;

				v0 = p.a & id;
				v1 = v3 = p.d;

				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
					a[i] = p.c * i;

				result[plain8Ofs + id] = v0;
				if (!fixed) v0++;

				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
					r += a[i];

				// i equals the array length here; (r + i) corresponds to the "return r + i" of the generated f1().
				result[id] = (r + i) + v0 + v1 + v3;
			}

			break;
		}
		case TEST_TYPE_NESTED_FUNCTION_CALL:
		{
			// Same element counts as a[] in the generated f0() and b[] in the generated f1().
			deUint32 a[14];
			deUint32 b[256];

			for (deUint32 id = 0; id < plainSize; ++id)
			{
				deUint32 r = 0;
				deUint32 i;
				deUint32 t = 0;
				deUint32 j;

				v0 = p.a & id;
				v3 = p.d;

				for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
					b[j] = p.c * j;

				v1 = p.b;

				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
					a[i] = p.c * i;

				result[plain8Ofs + id] = v0;
				if (!fixed) v0++;

				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
					r += a[i];

				for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
					t += b[j];

				// (r + i) is the inner function's return value and (t + j) adds the outer function's own sum and counter.
				result[id] = (r + i) + (t + j) + v0 + v1 + v3;
			}

			break;
		}

		default:
			TCU_THROW(InternalError, "Unknown testType");
	}

	// Layer 7 holds the linear id of each position.
	{
		const deUint32	startOfs	= 7 * plainSize;

		for (deUint32 n = 0; n < plainSize; ++n)
			result[startOfs + n] = n;
	}

	// Layers 1..6 hold the push constants, one constant per layer.
	for (deUint32 z = 1; z < DE_LENGTH_OF_ARRAY(pushConstants); ++z)
	{
		const deUint32	startOfs		= z * plainSize;
		const deUint32	pushConstant	= pushConstants[z];

		for (deUint32 n = 0; n < plainSize; ++n)
			result[startOfs + n] = pushConstant;
	}

	return result;
}
970
971tcu::TestStatus RayTracingComplexControlFlowInstance::iterate (void)
972{
973	const de::MovePtr<BufferWithMemory>	buffer		= runTest();
974	const deUint32*						bufferPtr	= (deUint32*)buffer->getAllocation().getHostPtr();
975	const vector<deUint32>				expected	= getExpectedValues();
976	tcu::TestLog&						log			= m_context.getTestContext().getLog();
977	deUint32							failures	= 0;
978	deUint32							pos			= 0;
979
980	for (deUint32 z = 0; z < m_depth; ++z)
981	for (deUint32 y = 0; y < m_data.height; ++y)
982	for (deUint32 x = 0; x < m_data.width; ++x)
983	{
984		if (bufferPtr[pos] != expected[pos])
985			failures++;
986
987		++pos;
988	}
989
990	if (failures != 0)
991	{
992		deUint32			pos0	= 0;
993		deUint32			pos1	= 0;
994		std::stringstream	css;
995
996		for (deUint32 z = 0; z < m_depth; ++z)
997		{
998			css << "z=" << z << std::endl;
999
1000			for (deUint32 y = 0; y < m_data.height; ++y)
1001			{
1002				for (deUint32 x = 0; x < m_data.width; ++x)
1003					css << std::setw(6) << bufferPtr[pos0++] << ' ';
1004
1005				css << "    ";
1006
1007				for (deUint32 x = 0; x < m_data.width; ++x)
1008					css << std::setw(6) << expected[pos1++] << ' ';
1009
1010				css << std::endl;
1011			}
1012
1013			css << std::endl;
1014		}
1015
1016		log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
1017	}
1018
1019	if (failures == 0)
1020		return tcu::TestStatus::pass("Pass");
1021	else
1022		return tcu::TestStatus::fail("failures=" + de::toString(failures));
1023}
1024
// Test case wrapper for one complex-control-flow configuration (a CaseDef).
// It generates the GLSL sources, checks device support and creates the instance that runs the case.
class ComplexControlFlowTestCase : public TestCase
{
	public:
										ComplexControlFlowTestCase	(tcu::TestContext& context, const char* name, const CaseDef data);
										~ComplexControlFlowTestCase	(void);

	// Adds the GLSL sources required by the configured operation and shader stage to programCollection.
	virtual	void						initPrograms				(SourceCollections& programCollection) const;
	// Returns a newly allocated RayTracingComplexControlFlowInstance for this case.
	virtual TestInstance*				createInstance				(Context& context) const;
	// Throws when the required extensions, features or limits are not available.
	virtual void						checkSupport				(Context& context) const;

private:
	// GLSL sources of the minimal intersection / miss / hit shaders used for stages that are not under test.
	static inline const std::string		getIntersectionPassthrough	(void);
	static inline const std::string		getMissPassthrough			(void);
	static inline const std::string		getHitPassthrough			(void);

	// Copy of the case definition passed to the constructor.
	CaseDef								m_data;
};
1042
1043ComplexControlFlowTestCase::ComplexControlFlowTestCase (tcu::TestContext& context, const char* name, const CaseDef data)
1044	: vkt::TestCase	(context, name)
1045	, m_data		(data)
1046{
1047}
1048
// Intentionally empty: the destructor body performs no explicit cleanup.
ComplexControlFlowTestCase::~ComplexControlFlowTestCase	(void)
{
}
1052
1053void ComplexControlFlowTestCase::checkSupport (Context& context) const
1054{
1055	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1056
1057	const VkPhysicalDeviceAccelerationStructureFeaturesKHR&	accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
1058
1059	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
1060		TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1061
1062	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1063
1064	const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&	rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
1065
1066	if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE)
1067		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1068
1069	const VkPhysicalDeviceRayTracingPipelinePropertiesKHR&	rayTracingPipelinePropertiesKHR = context.getRayTracingPipelineProperties();
1070
1071	if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
1072	{
1073		if (rayTracingPipelinePropertiesKHR.maxRayRecursionDepth < 2)
1074			TCU_THROW(NotSupportedError, "rayTracingPipelinePropertiesKHR.maxRayRecursionDepth is smaller than required");
1075	}
1076}
1077
1078
1079const std::string ComplexControlFlowTestCase::getIntersectionPassthrough (void)
1080{
1081	const std::string intersectionPassthrough =
1082		"#version 460 core\n"
1083		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1084		"#extension GL_EXT_ray_tracing : require\n"
1085		"hitAttributeEXT vec3 hitAttribute;\n"
1086		"\n"
1087		"void main()\n"
1088		"{\n"
1089		"  reportIntersectionEXT(0.95f, 0u);\n"
1090		"}\n";
1091
1092	return intersectionPassthrough;
1093}
1094
1095const std::string ComplexControlFlowTestCase::getMissPassthrough (void)
1096{
1097	const std::string missPassthrough =
1098		"#version 460 core\n"
1099		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1100		"#extension GL_EXT_ray_tracing : require\n"
1101		"layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1102		"\n"
1103		"void main()\n"
1104		"{\n"
1105		"}\n";
1106
1107	return missPassthrough;
1108}
1109
1110const std::string ComplexControlFlowTestCase::getHitPassthrough (void)
1111{
1112	const std::string hitPassthrough =
1113		"#version 460 core\n"
1114		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1115		"#extension GL_EXT_ray_tracing : require\n"
1116		"hitAttributeEXT vec3 attribs;\n"
1117		"layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1118		"\n"
1119		"void main()\n"
1120		"{\n"
1121		"}\n";
1122
1123	return hitPassthrough;
1124}
1125
1126void ComplexControlFlowTestCase::initPrograms (SourceCollections& programCollection) const
1127{
1128	const vk::ShaderBuildOptions	buildOptions			(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1129	const std::string				calleeMainPart			=
1130		"  uint z = (inValue.x % 8) + 8;\n"
1131		"  uint v = inValue.y;\n"
1132		"  uint n = gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y;\n"
1133		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, z), uvec4(v, 0, 0, 1));\n"
1134		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 7), uvec4(n, 0, 0, 1));\n";
1135	const std::string				idTemplate				= "$";
1136	const std::string				shaderCallInstruction	= (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)    ? "executeCallableEXT(0, " + idTemplate + ")"
1137															: (m_data.testOp == TEST_OP_TRACE_RAY)           ? "traceRayEXT(as, 0, 0xFF, p.hitOfs, 0, p.miss, vec3((gl_LaunchIDEXT.x) + vec3(0.5f)) / vec3(gl_LaunchSizeEXT), 1.0f, vec3(0.0f, 0.0f, 1.0f), 100.0f, " + idTemplate + ")"
1138															: (m_data.testOp == TEST_OP_REPORT_INTERSECTION) ? "reportIntersectionEXT(1.0f, 0u)"
1139															: "TEST_OP_NOT_IMPLEMENTED_FAILURE";
1140	std::string						declsPreMain			=
1141		"#version 460 core\n"
1142		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1143		"#extension GL_EXT_ray_tracing : require\n"
1144		"\n"
1145		"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1146		"layout(set = 0, binding = 1) uniform accelerationStructureEXT as;\n"
1147		"\n"
1148		"layout(push_constant) uniform TestParams\n"
1149		"{\n"
1150		"    uint a;\n"
1151		"    uint b;\n"
1152		"    uint c;\n"
1153		"    uint d;\n"
1154		"    uint hitOfs;\n"
1155		"    uint miss;\n"
1156		"} p;\n";
1157	std::string						declsInMainBeforeOp		=
1158		"  uint result = 0;\n"
1159		"  uint id = uint(gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y);\n";
1160	std::string						declsInMainAfterOp		=
1161		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 0), uvec4(result, 0, 0, 1));\n"
1162		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 1), uvec4(p.a, 0, 0, 1));\n"
1163		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 2), uvec4(p.b, 0, 0, 1));\n"
1164		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 3), uvec4(p.c, 0, 0, 1));\n"
1165		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 4), uvec4(p.d, 0, 0, 1));\n"
1166		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 5), uvec4(p.hitOfs, 0, 0, 1));\n"
1167		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 6), uvec4(p.miss, 0, 0, 1));\n";
1168	std::string						opInMain				= "";
1169	std::string						opPreMain				= "";
1170
1171	DE_ASSERT(!declsPreMain.empty() && PUSH_CONSTANTS_COUNT == 6);
1172
1173	switch (m_data.testType)
1174	{
1175		case TEST_TYPE_IF:
1176		{
1177			opInMain =
1178				"  v2 = v3 = uvec2(0, p.b);\n"
1179				"\n"
1180				"  if ((p.a & id) != 0)\n"
1181				"      { v0 = uvec2(0, p.c & id); v1 = uvec2(0, (p.d & id) + 1);" + replace(shaderCallInstruction, idTemplate, "0") + "; }\n"
1182				"  else\n"
1183				"      { v0 = uvec2(0, p.d & id); v1 = uvec2(0, (p.c & id) + 1);" + replace(shaderCallInstruction, idTemplate, "1") + "; }\n"
1184				"\n"
1185				"  result = v0.y + v1.y + v2.y + v3.y;\n";
1186
1187			break;
1188		}
1189		case TEST_TYPE_LOOP:
1190		{
1191			opInMain =
1192				"  v1 = v3 = uvec2(0, p.b);\n"
1193				"\n"
1194				"  for (uint x = 0; x < p.a; x++)\n"
1195				"  {\n"
1196				"    v0 = uvec2(x, (p.c & id) + x);\n"
1197				"    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1198				"    result += v0.y + v1.y + v3.y;\n"
1199				"  }\n";
1200
1201			break;
1202		}
1203		case TEST_TYPE_SWITCH:
1204		{
1205			opInMain =
1206				"  switch (p.a & id)\n"
1207				"  {\n"
1208				"    case 0: { v1 = v2 = v3 = uvec2(0, p.b); v0 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "0") + "; break; }\n"
1209				"    case 1: { v0 = v2 = v3 = uvec2(0, p.b); v1 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "1") + "; break; }\n"
1210				"    case 2: { v0 = v1 = v3 = uvec2(0, p.b); v2 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "2") + "; break; }\n"
1211				"    case 3: { v0 = v1 = v2 = uvec2(0, p.b); v3 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "3") + "; break; }\n"
1212				"    default: break;\n"
1213				"  }\n"
1214				"\n"
1215				"  result = v0.y + v1.y + v2.y + v3.y;\n";
1216
1217			break;
1218		}
1219		case TEST_TYPE_LOOP_DOUBLE_CALL:
1220		{
1221			opInMain =
1222				"  v3 = uvec2(0, p.b);\n"
1223				"  for (uint x = 0; x < p.a; x++)\n"
1224				"  {\n"
1225				"    v0 = uvec2(2 * x + 0, (p.c & id) + x);\n"
1226				"    v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1227				"    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1228				"    " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1229				"    result += v0.y + v1.y + v3.y;\n"
1230				"  }\n";
1231
1232			break;
1233		}
1234		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
1235		{
1236			opInMain =
1237				"  v3 = uvec2(0, p.a + p.b);\n"
1238				"  for (uint x = 0; x < p.a; x++)\n"
1239				"    if ((x & p.b) != 0)\n"
1240				"    {\n"
1241				"      v0 = uvec2(2 * x + 0, (p.c & id) + x + 0);\n"
1242				"      v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1243				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1244				"      " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1245				"      result += v0.y + v1.y + v3.y;\n"
1246				"    }\n"
1247				"\n";
1248
1249			break;
1250		}
1251		case TEST_TYPE_NESTED_LOOP:
1252		{
1253			opInMain =
1254				"  v1 = v3 = uvec2(0, p.b);\n"
1255				"  for (uint y = 0; y < p.a; y++)\n"
1256				"  for (uint x = 0; x < p.a; x++)\n"
1257				"  {\n"
1258				"    uint n = x + y * p.a;\n"
1259				"    if ((n & p.d) != 0)\n"
1260				"    {\n"
1261				"      v0 = uvec2(n, (p.c & id) + (x + y * p.a));\n"
1262				"      "+ replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1263				"      result += v0.y + v1.y + v3.y;\n"
1264				"    }\n"
1265				"  }\n"
1266				"\n";
1267
1268			break;
1269		}
1270		case TEST_TYPE_NESTED_LOOP_BEFORE:
1271		{
1272			opInMain =
1273				"  for (uint y = 0; y < p.d; y++)\n"
1274				"  for (uint x = 0; x < p.d; x++)\n"
1275				"    if (((x + y * p.a) & p.b) != 0)\n"
1276				"      result += (x + y);\n"
1277				"\n"
1278				"  v1 = v3 = uvec2(0, p.a);\n"
1279				"\n"
1280				"  for (uint x = 0; x < p.b; x++)\n"
1281				"    if ((x & p.a) != 0)\n"
1282				"    {\n"
1283				"      v0 = uvec2(x, p.c & id);\n"
1284				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1285				"      result += v0.y + v1.y + v3.y;\n"
1286				"    }\n";
1287
1288			break;
1289		}
1290		case TEST_TYPE_NESTED_LOOP_AFTER:
1291		{
1292			opInMain =
1293				"  v1 = v3 = uvec2(0, p.a); \n"
1294				"  for (uint x = 0; x < p.b; x++)\n"
1295				"    if ((x & p.a) != 0)\n"
1296				"    {\n"
1297				"      v0 = uvec2(x, p.c & id);\n"
1298				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1299				"      result += v0.y + v1.y + v3.y;\n"
1300				"    }\n"
1301				"\n"
1302				"  for (uint y = 0; y < p.d; y++)\n"
1303				"  for (uint x = 0; x < p.d; x++)\n"
1304				"    if (((x + y * p.a) & p.b) != 0)\n"
1305				"      result += x + y;\n";
1306
1307			break;
1308		}
1309		case TEST_TYPE_FUNCTION_CALL:
1310		{
1311			opPreMain =
1312				"uint f1(void)\n"
1313				"{\n"
1314				"  uint i, r = 0;\n"
1315				"  uint a[42];\n"
1316				"\n"
1317				"  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1318				"\n"
1319				"  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1320				"\n"
1321				"  for (i = 0; i < a.length(); i++) r += a[i];\n"
1322				"\n"
1323				"  return r + i;\n"
1324				"}\n";
1325			opInMain =
1326				"  v0 = uvec2(0, p.a & id); v1 = v3 = uvec2(0, p.d);\n"
1327				"  result = f1() + v0.y + v1.y + v3.y;\n";
1328
1329			break;
1330		}
1331		case TEST_TYPE_NESTED_FUNCTION_CALL:
1332		{
1333			opPreMain =
1334				"uint f0(void)\n"
1335				"{\n"
1336				"  uint i, r = 0;\n"
1337				"  uint a[14];\n"
1338				"\n"
1339				"  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1340				"\n"
1341				"  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1342				"\n"
1343				"  for (i = 0; i < a.length(); i++) r += a[i];\n"
1344				"\n"
1345				"  return r + i;\n"
1346				"}\n"
1347				"\n"
1348				"uint f1(void)\n"
1349				"{\n"
1350				"  uint j, t = 0;\n"
1351				"  uint b[256];\n"
1352				"\n"
1353				"  for (j = 0; j < b.length(); j++) b[j] = p.c * j;\n"
1354				"\n"
1355				"  v1 = uvec2(0, p.b);\n"
1356				"\n"
1357				"  t += f0();\n"
1358				"\n"
1359				"  for (j = 0; j < b.length(); j++) t += b[j];\n"
1360				"\n"
1361				"  return t + j;\n"
1362				"}\n";
1363			opInMain =
1364				"  v0 = uvec2(0, p.a & id); v3 = uvec2(0, p.d);\n"
1365				"  result = f1() + v0.y + v1.y + v3.y;\n";
1366
1367			break;
1368		}
1369
1370		default:
1371			TCU_THROW(InternalError, "Unknown testType");
1372	}
1373
1374	if (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)
1375	{
1376		const std::string	calleeShader			=
1377			"#version 460 core\n"
1378			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1379			"#extension GL_EXT_ray_tracing : require\n"
1380			"\n"
1381			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1382			"layout(location = 0) callableDataInEXT uvec2 inValue;\n"
1383			"\n"
1384			"void main()\n"
1385			"{\n"
1386			+ calleeMainPart +
1387			"  inValue.y++;\n"
1388			"}\n";
1389
1390		declsPreMain +=
1391			"layout(location = 0) callableDataEXT uvec2 v0;\n"
1392			"layout(location = 1) callableDataEXT uvec2 v1;\n"
1393			"layout(location = 2) callableDataEXT uvec2 v2;\n"
1394			"layout(location = 3) callableDataEXT uvec2 v3;\n"
1395			"\n";
1396
1397		switch (m_data.stage)
1398		{
1399			case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1400			{
1401				std::stringstream css;
1402				css << declsPreMain
1403					<< opPreMain
1404					<< "\n"
1405					<< "void main()\n"
1406					<< "{\n"
1407					<< declsInMainBeforeOp
1408					<< opInMain // executeCallableEXT
1409					<< declsInMainAfterOp
1410					<< "}\n";
1411
1412				programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1413				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1414
1415				break;
1416			}
1417
1418			case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1419			{
1420				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1421
1422				std::stringstream css;
1423				css << declsPreMain
1424					<< "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1425					<< "hitAttributeEXT vec3 attribs;\n"
1426					<< "\n"
1427					<< opPreMain
1428					<< "\n"
1429					<< "void main()\n"
1430					<< "{\n"
1431					<< declsInMainBeforeOp
1432					<< opInMain // executeCallableEXT
1433					<< declsInMainAfterOp
1434					<< "}\n";
1435
1436				programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1437				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1438
1439				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1440				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1441				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1442
1443				break;
1444			}
1445
1446			case VK_SHADER_STAGE_MISS_BIT_KHR:
1447			{
1448				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1449
1450				std::stringstream css;
1451				css << declsPreMain
1452					<< opPreMain
1453					<< "\n"
1454					<< "void main()\n"
1455					<< "{\n"
1456					<< declsInMainBeforeOp
1457					<< opInMain // executeCallableEXT
1458					<< declsInMainAfterOp
1459					<< "}\n";
1460
1461				programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1462				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1463
1464				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1465				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1466				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1467
1468				break;
1469			}
1470
1471			case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1472			{
1473				{
1474					std::stringstream css;
1475					css << "#version 460 core\n"
1476						<< "#extension GL_EXT_nonuniform_qualifier : enable\n"
1477						<< "#extension GL_EXT_ray_tracing : require\n"
1478						<< "\n"
1479						<< "layout(location = 4) callableDataEXT float dummy;\n"
1480						<< "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1481						<< "\n"
1482						<< "void main()\n"
1483						<< "{\n"
1484						<< "  executeCallableEXT(1, 4);\n"
1485						<< "}\n";
1486
1487					programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1488				}
1489
1490				{
1491					std::stringstream css;
1492					css << declsPreMain
1493						<< "layout(location = 4) callableDataInEXT float dummyIn;\n"
1494						<< opPreMain
1495						<< "\n"
1496						<< "void main()\n"
1497						<< "{\n"
1498						<< declsInMainBeforeOp
1499						<< opInMain // executeCallableEXT
1500						<< declsInMainAfterOp
1501						<< "}\n";
1502
1503					programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1504				}
1505
1506				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1507
1508				break;
1509			}
1510
1511			default:
1512				TCU_THROW(InternalError, "Unknown stage");
1513		}
1514	}
1515	else if (m_data.testOp == TEST_OP_TRACE_RAY)
1516	{
1517		const std::string	missShader	=
1518			"#version 460 core\n"
1519			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1520			"#extension GL_EXT_ray_tracing : require\n"
1521			"\n"
1522			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1523			"layout(location = 0) rayPayloadInEXT uvec2 inValue;\n"
1524			"\n"
1525			"void main()\n"
1526			"{\n"
1527			+ calleeMainPart +
1528			"  inValue.y++;\n"
1529			"}\n";
1530
1531		declsPreMain +=
1532			"layout(location = 0) rayPayloadEXT uvec2 v0;\n"
1533			"layout(location = 1) rayPayloadEXT uvec2 v1;\n"
1534			"layout(location = 2) rayPayloadEXT uvec2 v2;\n"
1535			"layout(location = 3) rayPayloadEXT uvec2 v3;\n";
1536
1537		switch (m_data.stage)
1538		{
1539			case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1540			{
1541				std::stringstream css;
1542				css << declsPreMain
1543					<< opPreMain
1544					<< "\n"
1545					<< "void main()\n"
1546					<< "{\n"
1547					<< declsInMainBeforeOp
1548					<< opInMain // traceRayEXT
1549					<< declsInMainAfterOp
1550					<< "}\n";
1551
1552				programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1553
1554				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1555				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1556				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1557				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1558
1559				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1560				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1561				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1562				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1563
1564				break;
1565			}
1566
1567			case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1568			{
1569				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1570
1571				std::stringstream css;
1572				css << declsPreMain
1573					<< opPreMain
1574					<< "\n"
1575					<< "void main()\n"
1576					<< "{\n"
1577					<< declsInMainBeforeOp
1578					<< opInMain // traceRayEXT
1579					<< declsInMainAfterOp
1580					<< "}\n";
1581
1582				programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1583
1584				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1585				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1586				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1587
1588				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1589				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1590				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1591				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1592
1593				break;
1594			}
1595
1596			case VK_SHADER_STAGE_MISS_BIT_KHR:
1597			{
1598				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1599
1600				std::stringstream css;
1601				css << declsPreMain
1602					<< opPreMain
1603					<< "\n"
1604					<< "void main()\n"
1605					<< "{\n"
1606					<< declsInMainBeforeOp
1607					<< opInMain // traceRayEXT
1608					<< declsInMainAfterOp
1609					<< "}\n";
1610
1611				programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1612
1613				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1614				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1615				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1616
1617				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1618				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1619				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1620				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1621
1622				break;
1623			}
1624
1625			default:
1626				TCU_THROW(InternalError, "Unknown stage");
1627		}
1628	}
1629	else if (m_data.testOp == TEST_OP_REPORT_INTERSECTION)
1630	{
1631		const std::string	anyHitShader		=
1632			"#version 460 core\n"
1633			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1634			"#extension GL_EXT_ray_tracing : require\n"
1635			"\n"
1636			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1637			"hitAttributeEXT block { uvec2 inValue; };\n"
1638			"\n"
1639			"void main()\n"
1640			"{\n"
1641			+ calleeMainPart +
1642			"}\n";
1643
1644		declsPreMain +=
1645			"hitAttributeEXT block { uvec2 v0; };\n"
1646			"uvec2 v1;\n"
1647			"uvec2 v2;\n"
1648			"uvec2 v3;\n";
1649
1650		switch (m_data.stage)
1651		{
1652			case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1653			{
1654				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1655
1656				std::stringstream css;
1657				css << declsPreMain
1658					<< opPreMain
1659					<< "\n"
1660					<< "void main()\n"
1661					<< "{\n"
1662					<< declsInMainBeforeOp
1663					<< opInMain // reportIntersectionEXT
1664					<< declsInMainAfterOp
1665					<< "}\n";
1666
1667				programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1668				programCollection.glslSources.add("ahit") << glu::AnyHitSource(anyHitShader) << buildOptions;
1669
1670				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1671				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1672
1673				break;
1674			}
1675
1676			default:
1677				TCU_THROW(InternalError, "Unknown stage");
1678		}
1679	}
1680	else
1681	{
1682		TCU_THROW(InternalError, "Unknown operation");
1683	}
1684}
1685
1686TestInstance* ComplexControlFlowTestCase::createInstance (Context& context) const
1687{
1688	return new RayTracingComplexControlFlowInstance(context, m_data);
1689}
1690
1691}	// anonymous
1692
1693tcu::TestCaseGroup*	createComplexControlFlowTests (tcu::TestContext& testCtx)
1694{
1695	const VkShaderStageFlagBits	R	= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
1696	const VkShaderStageFlagBits	A	= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
1697	const VkShaderStageFlagBits	C	= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
1698	const VkShaderStageFlagBits	M	= VK_SHADER_STAGE_MISS_BIT_KHR;
1699	const VkShaderStageFlagBits	I	= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1700	const VkShaderStageFlagBits	L	= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1701
1702	DE_UNREF(A);
1703
1704	static const struct
1705	{
1706		const char*				name;
1707		VkShaderStageFlagBits	stage;
1708	}
1709	testStages[]
1710	{
1711		{ "rgen", VK_SHADER_STAGE_RAYGEN_BIT_KHR		},
1712		{ "chit", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	},
1713		{ "ahit", VK_SHADER_STAGE_ANY_HIT_BIT_KHR		},
1714		{ "sect", VK_SHADER_STAGE_INTERSECTION_BIT_KHR	},
1715		{ "miss", VK_SHADER_STAGE_MISS_BIT_KHR			},
1716		{ "call", VK_SHADER_STAGE_CALLABLE_BIT_KHR		},
1717	};
1718	static const struct
1719	{
1720		const char*			name;
1721		TestOp				op;
1722		VkShaderStageFlags	applicableInStages;
1723	}
1724	testOps[]
1725	{
1726		{ "execute_callable",		TEST_OP_EXECUTE_CALLABLE,		R |    C | M     | L },
1727		{ "trace_ray",				TEST_OP_TRACE_RAY,				R |    C | M         },
1728		{ "report_intersection",	TEST_OP_REPORT_INTERSECTION,	               I     },
1729	};
1730	static const struct
1731	{
1732		const char*	name;
1733		TestType	testType;
1734	}
1735	testTypes[]
1736	{
1737		{ "if",							TEST_TYPE_IF						},
1738		{ "loop",						TEST_TYPE_LOOP						},
1739		{ "switch",						TEST_TYPE_SWITCH					},
1740		{ "loop_double_call",			TEST_TYPE_LOOP_DOUBLE_CALL			},
1741		{ "loop_double_call_sparse",	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE	},
1742		{ "nested_loop",				TEST_TYPE_NESTED_LOOP				},
1743		{ "nested_loop_loop_before",	TEST_TYPE_NESTED_LOOP_BEFORE		},
1744		{ "nested_loop_loop_after",		TEST_TYPE_NESTED_LOOP_AFTER			},
1745		{ "function_call",				TEST_TYPE_FUNCTION_CALL				},
1746		{ "nested_function_call",		TEST_TYPE_NESTED_FUNCTION_CALL		},
1747	};
1748
1749	// Ray tracing complex control flow tests
1750	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "complexcontrolflow"));
1751
1752	for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
1753	{
1754		const TestType					testType		= testTypes[testTypeNdx].testType;
1755		de::MovePtr<tcu::TestCaseGroup> testTypeGroup	(new tcu::TestCaseGroup(testCtx, testTypes[testTypeNdx].name));
1756
1757		for (size_t testOpNdx = 0; testOpNdx < DE_LENGTH_OF_ARRAY(testOps); ++testOpNdx)
1758		{
1759			const TestOp					testOp		= testOps[testOpNdx].op;
1760			de::MovePtr<tcu::TestCaseGroup> testOpGroup	(new tcu::TestCaseGroup(testCtx, testOps[testOpNdx].name));
1761
1762			for (size_t testStagesNdx = 0; testStagesNdx < DE_LENGTH_OF_ARRAY(testStages); ++testStagesNdx)
1763			{
1764				const VkShaderStageFlagBits	testStage				= testStages[testStagesNdx].stage;
1765				const std::string			testName				= de::toString(testStages[testStagesNdx].name);
1766				const deUint32				width					= 4u;
1767				const deUint32				height					= 4u;
1768				const CaseDef				caseDef					=
1769				{
1770					testType,				//  TestType				testType;
1771					testOp,					//  TestOp					testOp;
1772					testStage,				//  VkShaderStageFlagBits	stage;
1773					width,					//  deUint32				width;
1774					height,					//  deUint32				height;
1775				};
1776
1777				if ((testOps[testOpNdx].applicableInStages & static_cast<VkShaderStageFlags>(testStage)) == 0)
1778					continue;
1779
1780				testOpGroup->addChild(new ComplexControlFlowTestCase(testCtx, testName.c_str(), caseDef));
1781			}
1782
1783			testTypeGroup->addChild(testOpGroup.release());
1784		}
1785
1786		group->addChild(testTypeGroup.release());
1787	}
1788
1789	return group.release();
1790}
1791
1792}	// RayTracing
1793}	// vkt
1794