1/*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 The Khronos Group Inc.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *	  http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 *//*!
20 * \file
21 * \brief Ray Tracing Callable Shader tests
22 *//*--------------------------------------------------------------------*/
23
24#include "vktRayTracingCallableShadersTests.hpp"
25
26#include "vkDefs.hpp"
27
28#include "vktTestCase.hpp"
29#include "vktTestGroupUtil.hpp"
30#include "vkCmdUtil.hpp"
31#include "vkObjUtil.hpp"
32#include "vkBuilderUtil.hpp"
33#include "vkBarrierUtil.hpp"
34#include "vkBufferWithMemory.hpp"
35#include "vkImageWithMemory.hpp"
36#include "vkTypeUtil.hpp"
37#include "vkImageUtil.hpp"
38#include "deRandom.hpp"
39#include "tcuSurface.hpp"
40#include "tcuTexture.hpp"
41#include "tcuTextureUtil.hpp"
42#include "tcuTestLog.hpp"
43#include "tcuImageCompare.hpp"
44
45#include "vkRayTracingUtil.hpp"
46
47namespace vkt
48{
49namespace RayTracing
50{
51namespace
52{
53using namespace vk;
54using namespace vkt;
55
56static const VkFlags	ALL_RAY_TRACING_STAGES	= VK_SHADER_STAGE_RAYGEN_BIT_KHR
57												| VK_SHADER_STAGE_ANY_HIT_BIT_KHR
58												| VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
59												| VK_SHADER_STAGE_MISS_BIT_KHR
60												| VK_SHADER_STAGE_INTERSECTION_BIT_KHR
61												| VK_SHADER_STAGE_CALLABLE_BIT_KHR;
62
// Test variants. Each value selects the shader set (SingleSquareConfiguration::initRayTracingShaders),
// the callable SBT layout (initShaderBindingTables) and the expected pixel values (verifyImage).
enum CallableShaderTestType
{
	CSTT_RGEN_CALL		= 0,	// raygen calls one callable shader; no ray is traced
	CSTT_RGEN_CALL_CALL	= 1,	// raygen calls a callable which in turn calls a second callable
	CSTT_HIT_CALL		= 2,	// raygen traces a ray; the closest hit and miss shaders each call a callable
	CSTT_RGEN_MULTICALL	= 3,	// raygen calls four callables using payloads of different types
	CSTT_COUNT					// one past the last variant
};
71
// Result image dimensions in pixels. NOTE(review): presumably copied into TestParams::width/height
// by the test group builder, which is not visible in this chunk — confirm.
const deUint32			TEST_WIDTH			= 8;
const deUint32			TEST_HEIGHT			= 8;

// Forward declaration: the TestConfiguration interface below takes TestParams by reference.
struct TestParams;
76
77class TestConfiguration
78{
79public:
80	virtual std::vector<de::SharedPtr<BottomLevelAccelerationStructure>>	initBottomAccelerationStructures	(Context&							context,
81																												 TestParams&						testParams) = 0;
82	virtual de::MovePtr<TopLevelAccelerationStructure>						initTopAccelerationStructure		(Context&							context,
83																												 TestParams&						testParams,
84																												 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures) = 0;
85	virtual void															initRayTracingShaders				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
86																												 Context&							context,
87																												TestParams&							testParams) = 0;
88	virtual void															initShaderBindingTables				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
89																												 Context&							context,
90																												 TestParams&						testParams,
91																												 VkPipeline							pipeline,
92																												 deUint32							shaderGroupHandleSize,
93																												 deUint32							shaderGroupBaseAlignment,
94																												 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
95																												 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
96																												 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
97																												 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
98																												 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
99																												 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
100																												 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
101																												 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion) = 0;
102	virtual bool															verifyImage							(BufferWithMemory*					resultBuffer,
103																												 Context&							context,
104																												 TestParams&						testParams) = 0;
105	virtual VkFormat														getResultImageFormat				() = 0;
106	virtual size_t															getResultImageFormatSize			() = 0;
107	virtual VkClearValue													getClearValue						() = 0;
108};
109
// Parameters shared by a test case and its instance.
struct TestParams
{
	deUint32							width;					// width of the result image in pixels
	deUint32							height;					// height of the result image in pixels
	CallableShaderTestType				callableShaderTestType;	// variant: selects shaders, SBT layout and expected values
	de::SharedPtr<TestConfiguration>	testConfiguration;		// scene construction and verification strategy
    glu::ShaderType						invokingShader;			// NOTE(review): not read by code visible in this chunk; presumably the stage issuing executeCallableEXT — confirm
	bool								multipleInvocations;	// NOTE(review): not read by code visible in this chunk — confirm its meaning
};
119
120deUint32 getShaderGroupHandleSize (const InstanceInterface&	vki,
121								   const VkPhysicalDevice	physicalDevice)
122{
123	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
124
125	rayTracingPropertiesKHR	= makeRayTracingProperties(vki, physicalDevice);
126	return rayTracingPropertiesKHR->getShaderGroupHandleSize();
127}
128
129deUint32 getShaderGroupBaseAlignment (const InstanceInterface&	vki,
130									  const VkPhysicalDevice	physicalDevice)
131{
132	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
133
134	rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
135	return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
136}
137
138VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, VkFormat format)
139{
140	const VkImageCreateInfo			imageCreateInfo			=
141	{
142		VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,																// VkStructureType			sType;
143		DE_NULL,																							// const void*				pNext;
144		(VkImageCreateFlags)0u,																				// VkImageCreateFlags		flags;
145		VK_IMAGE_TYPE_2D,																					// VkImageType				imageType;
146		format,																								// VkFormat					format;
147		makeExtent3D(width, height, 1),																		// VkExtent3D				extent;
148		1u,																									// deUint32					mipLevels;
149		1u,																									// deUint32					arrayLayers;
150		VK_SAMPLE_COUNT_1_BIT,																				// VkSampleCountFlagBits	samples;
151		VK_IMAGE_TILING_OPTIMAL,																			// VkImageTiling			tiling;
152		VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT,		// VkImageUsageFlags		usage;
153		VK_SHARING_MODE_EXCLUSIVE,																			// VkSharingMode			sharingMode;
154		0u,																									// deUint32					queueFamilyIndexCount;
155		DE_NULL,																							// const deUint32*			pQueueFamilyIndices;
156		VK_IMAGE_LAYOUT_UNDEFINED																			// VkImageLayout			initialLayout;
157	};
158
159	return imageCreateInfo;
160}
161
// Scene made of one bottom-level structure holding a square of two triangles at z = 0.
// The square covers [1, width-1] x [1, height-1] (see initBottomAccelerationStructures),
// and results are stored as VK_FORMAT_R32_UINT.
class SingleSquareConfiguration : public TestConfiguration
{
public:
	std::vector<de::SharedPtr<BottomLevelAccelerationStructure>>	initBottomAccelerationStructures	(Context&							context,
																										 TestParams&						testParams) override;
	de::MovePtr<TopLevelAccelerationStructure>						initTopAccelerationStructure		(Context&							context,
																										 TestParams&						testParams,
																										 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures) override;
	void															initRayTracingShaders				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
																										 Context&							context,
																										 TestParams&						testParams) override;
	void															initShaderBindingTables				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
																										 Context&							context,
																										 TestParams&						testParams,
																										 VkPipeline							pipeline,
																										 deUint32							shaderGroupHandleSize,
																										 deUint32							shaderGroupBaseAlignment,
																										 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
																										 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
																										 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
																										 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
																										 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
																										 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
																										 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
																										 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion) override;
	bool															verifyImage							(BufferWithMemory*					resultBuffer,
																										 Context&							context,
																										 TestParams&						testParams) override;
	VkFormat														getResultImageFormat				() override;
	size_t															getResultImageFormatSize			() override;
	VkClearValue													getClearValue						() override;
};
194
195std::vector<de::SharedPtr<BottomLevelAccelerationStructure> > SingleSquareConfiguration::initBottomAccelerationStructures (Context&			context,
196																														   TestParams&		testParams)
197{
198	DE_UNREF(context);
199
200	tcu::Vec3 v0(1.0, float(testParams.height) - 1.0f, 0.0);
201	tcu::Vec3 v1(1.0, 1.0, 0.0);
202	tcu::Vec3 v2(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0);
203	tcu::Vec3 v3(float(testParams.width) - 1.0f, 1.0, 0.0);
204
205	std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >	result;
206	de::MovePtr<BottomLevelAccelerationStructure>					bottomLevelAccelerationStructure	= makeBottomLevelAccelerationStructure();
207	bottomLevelAccelerationStructure->setGeometryCount(1);
208
209	de::SharedPtr<RaytracedGeometryBase> geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
210	geometry->addVertex(v0);
211	geometry->addVertex(v1);
212	geometry->addVertex(v2);
213	geometry->addVertex(v2);
214	geometry->addVertex(v1);
215	geometry->addVertex(v3);
216	bottomLevelAccelerationStructure->addGeometry(geometry);
217
218	result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
219
220	return result;
221}
222
223de::MovePtr<TopLevelAccelerationStructure> SingleSquareConfiguration::initTopAccelerationStructure (Context&		context,
224																									TestParams&		testParams,
225																									std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >& bottomLevelAccelerationStructures)
226{
227	DE_UNREF(context);
228	DE_UNREF(testParams);
229
230	de::MovePtr<TopLevelAccelerationStructure>	result						= makeTopLevelAccelerationStructure();
231	result->setInstanceCount(1);
232	result->addInstance(bottomLevelAccelerationStructures[0]);
233
234	return result;
235}
236
237void SingleSquareConfiguration::initRayTracingShaders (de::MovePtr<RayTracingPipeline>&		rayTracingPipeline,
238													   Context&								context,
239													   TestParams&							testParams)
240{
241	const DeviceInterface&						vkd						= context.getDeviceInterface();
242	const VkDevice								device					= context.getDevice();
243
244	switch (testParams.callableShaderTestType)
245	{
246		case CSTT_RGEN_CALL:
247		{
248			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
249			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
250			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
251			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
252			break;
253		}
254		case CSTT_RGEN_CALL_CALL:
255		{
256			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
257			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
258			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
259			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_call"), 0), 3);
260			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 4);
261			break;
262		}
263		case CSTT_HIT_CALL:
264		{
265			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen"), 0), 0);
266			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit_call"), 0), 1);
267			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss_call"), 0), 2);
268			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
269			break;
270		}
271		case CSTT_RGEN_MULTICALL:
272		{
273			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_multicall"), 0), 0);
274			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
275			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
276			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
277			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_1"), 0), 4);
278			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_2"), 0), 5);
279			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_3"), 0), 6);
280			break;
281		}
282		default:
283			TCU_THROW(InternalError, "Wrong shader test type");
284	}
285}
286
287void SingleSquareConfiguration::initShaderBindingTables (de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
288														 Context&							context,
289														 TestParams&						testParams,
290														 VkPipeline							pipeline,
291														 deUint32							shaderGroupHandleSize,
292														 deUint32							shaderGroupBaseAlignment,
293														 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
294														 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
295														 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
296														 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
297														 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
298														 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
299														 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
300														 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion)
301{
302	const DeviceInterface&						vkd							= context.getDeviceInterface();
303	const VkDevice								device						= context.getDevice();
304	Allocator&									allocator					= context.getDefaultAllocator();
305
306	switch (testParams.callableShaderTestType)
307	{
308		case CSTT_RGEN_CALL:
309		{
310			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
311			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
312			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
313			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
314
315			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
316			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
317			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
318			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
319			break;
320		}
321		case CSTT_RGEN_CALL_CALL:
322		{
323			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
324			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
325			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
326			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 2);
327
328			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
329			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
330			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
331			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 2*shaderGroupHandleSize);
332			break;
333		}
334		case CSTT_HIT_CALL:
335		{
336			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
337			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
338			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
339			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
340
341			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
342			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
343			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
344			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
345			break;
346		}
347		case CSTT_RGEN_MULTICALL:
348		{
349			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
350			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
351			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
352			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 4);
353
354			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
355			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
356			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
357			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 4*shaderGroupHandleSize);
358			break;
359		}
360		default:
361			TCU_THROW(InternalError, "Wrong shader test type");
362	}
363}
364
365bool SingleSquareConfiguration::verifyImage (BufferWithMemory* resultBuffer, Context& context, TestParams& testParams)
366{
367	// create result image
368	tcu::TextureFormat			imageFormat						= vk::mapVkFormat(getResultImageFormat());
369	tcu::ConstPixelBufferAccess	resultAccess(imageFormat, testParams.width, testParams.height, 1, resultBuffer->getAllocation().getHostPtr());
370
371	// create reference image
372	std::vector<deUint32>		reference(testParams.width * testParams.height);
373	tcu::PixelBufferAccess		referenceAccess(imageFormat, testParams.width, testParams.height, 1, reference.data());
374
375	tcu::UVec4 missValue, hitValue;
376
377	// clear reference image with hit and miss values ( hit works only for tests calling traceRayEXT in rgen shader )
378	switch (testParams.callableShaderTestType)
379	{
380		case CSTT_RGEN_CALL:
381			missValue	= tcu::UVec4(1, 0, 0, 0);
382			hitValue	= tcu::UVec4(1, 0, 0, 0);
383			break;
384		case CSTT_RGEN_CALL_CALL:
385			missValue	= tcu::UVec4(1, 0, 0, 0);
386			hitValue	= tcu::UVec4(1, 0, 0, 0);
387			break;
388		case CSTT_HIT_CALL:
389			missValue	= tcu::UVec4(1, 0, 0, 0);
390			hitValue	= tcu::UVec4(2, 0, 0, 0);
391			break;
392		case CSTT_RGEN_MULTICALL:
393			missValue	= tcu::UVec4(16, 0, 0, 0);
394			hitValue	= tcu::UVec4(16, 0, 0, 0);
395			break;
396		default:
397			TCU_THROW(InternalError, "Wrong shader test type");
398	}
399
400	tcu::clear(referenceAccess, missValue);
401	for (deUint32 y = 1; y < testParams.width - 1; ++y)
402	for (deUint32 x = 1; x < testParams.height - 1; ++x)
403		referenceAccess.setPixel(hitValue, x, y);
404
405	// compare result and reference
406	return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess, resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
407}
408
409VkFormat SingleSquareConfiguration::getResultImageFormat ()
410{
411	return VK_FORMAT_R32_UINT;
412}
413
414size_t SingleSquareConfiguration::getResultImageFormatSize ()
415{
416	return sizeof(deUint32);
417}
418
419VkClearValue SingleSquareConfiguration::getClearValue ()
420{
421	return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
422}
423
424class CallableShaderTestCase : public TestCase
425{
426	public:
427							CallableShaderTestCase			(tcu::TestContext& context, const char* name, const TestParams data);
428							~CallableShaderTestCase			(void);
429
430	virtual void			checkSupport								(Context& context) const;
431	virtual	void			initPrograms								(SourceCollections& programCollection) const;
432	virtual TestInstance*	createInstance								(Context& context) const;
433private:
434	TestParams				m_data;
435};
436
437class CallableShaderTestInstance : public TestInstance
438{
439public:
440																	CallableShaderTestInstance	(Context& context, const TestParams& data);
441																	~CallableShaderTestInstance	(void);
442	tcu::TestStatus													iterate									(void);
443
444protected:
445	de::MovePtr<BufferWithMemory>									runTest									();
446private:
447	TestParams														m_data;
448};
449
// Stores a copy of the parameters; createInstance hands them to each test instance.
CallableShaderTestCase::CallableShaderTestCase (tcu::TestContext& context, const char* name, const TestParams data)
	: vkt::TestCase	(context, name)
	, m_data		(data)
{
}
455
456CallableShaderTestCase::~CallableShaderTestCase (void)
457{
458}
459
460void CallableShaderTestCase::checkSupport (Context& context) const
461{
462	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
463	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
464
465	const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&	rayTracingPipelineFeaturesKHR		= context.getRayTracingPipelineFeatures();
466	if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE )
467		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
468
469	const VkPhysicalDeviceAccelerationStructureFeaturesKHR&	accelerationStructureFeaturesKHR	= context.getAccelerationStructureFeatures();
470	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
471		TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
472}
473
// Registers every GLSL ray tracing program used by the test variants. Each variant picks its
// subset by name in SingleSquareConfiguration::initRayTracingShaders.
void CallableShaderTestCase::initPrograms (SourceCollections& programCollection) const
{
	// All programs are compiled for SPIR-V 1.4.
	const vk::ShaderBuildOptions	buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
	// "rgen": traces one ray per launch pixel from (x + 0.5, y + 0.5, 0.5) along -z (tmax 1.0) and stores the payload.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
			"\n"
			"void main()\n"
			"{\n"
			"  float tmin     = 0.0;\n"
			"  float tmax     = 1.0;\n"
			"  vec3  origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5f, float(gl_LaunchIDEXT.y) + 0.5f, 0.5f);\n"
			"  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
			"  hitValue       = uvec4(0,0,0,0);\n"
			"  traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), hitValue);\n"
			"}\n";
		programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "rgen_call": invokes the callable at SBT index 0 (payload location 0) and stores its result; no ray is traced.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) callableDataEXT uvec4 value;\n"
			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
			"\n"
			"void main()\n"
			"{\n"
			"  executeCallableEXT(0, 0);\n"
			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), value);\n"
			"}\n";
		programCollection.glslSources.add("rgen_call") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "rgen_multicall": invokes the callables at SBT indices 0-3 with payloads of four different types
	// and stores a single value combined from all of them.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"struct CallValue\n"
			"{\n"
			"  ivec4 a;\n"
			"  vec4  b;\n"
			"};\n"
			"layout(location = 0) callableDataEXT uvec4 value0;\n"
			"layout(location = 1) callableDataEXT uint value1;\n"
			"layout(location = 2) callableDataEXT CallValue value2;\n"
			"layout(location = 4) callableDataEXT vec3 value3;\n"
			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
			"\n"
			"void main()\n"
			"{\n"
			"  executeCallableEXT(0, 0);\n"
			"  executeCallableEXT(1, 1);\n"
			"  executeCallableEXT(2, 2);\n"
			"  executeCallableEXT(3, 4);\n"
			"  uint resultValue = value0.x + value1 + value2.a.x * uint(floor(value2.b.y)) + uint(floor(value3.z));\n"
			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), uvec4(resultValue, 0, 0, 0));\n"
			"}\n";
		programCollection.glslSources.add("rgen_multicall") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "chit": closest hit shader that returns (1, 0, 0, 1).
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
			"void main()\n"
			"{\n"
			"  hitValue = uvec4(1,0,0,1);\n"
			"}\n";

		programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "chit_call": closest hit shader that returns the result of callable 0 with x incremented by one.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) callableDataEXT uvec4 value;\n"
			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
			"void main()\n"
			"{\n"
			"  executeCallableEXT(0, 0);\n"
			"  hitValue = value;\n"
			"  hitValue.x = hitValue.x + 1;\n"
			"}\n";

		programCollection.glslSources.add("chit_call") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "miss": miss shader that returns (0, 0, 0, 1).
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
			"void main()\n"
			"{\n"
			"  hitValue = uvec4(0,0,0,1);\n"
			"}\n";

		programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "miss_call": miss shader that returns the result of callable 0 unchanged.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) callableDataEXT uvec4 value;\n"
			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
			"void main()\n"
			"{\n"
			"  executeCallableEXT(0, 0);\n"
			"  hitValue = value;\n"
			"}\n";

		programCollection.glslSources.add("miss_call") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// Callables "call_<idx>" are built from entry idx of the two tables below. The first table holds
	// the incoming payload declaration, matching value0..value3 in "rgen_multicall". The second holds
	// the value the callable writes into that payload.
	std::vector<std::string> callableDataDefinition =
	{
		"layout(location = 0) callableDataInEXT uvec4 result;\n",
		"layout(location = 1) callableDataInEXT uint result;\n",
		"struct CallValue\n{\n  ivec4 a;\n  vec4  b;\n};\nlayout(location = 2) callableDataInEXT CallValue result;\n",
		"layout(location = 4) callableDataInEXT vec3 result;\n"
	};

	std::vector<std::string> callableDataComputation =
	{
		"  result = uvec4(1,0,0,1);\n",
		"  result = 2;\n",
		"  result.a = ivec4(3,0,0,1);\n  result.b = vec4(1.0, 3.2, 0.0, 1);\n",
		"  result = vec3(0.0, 0.0, 4.3);\n",
	};

	for (deUint32 idx = 0; idx < callableDataDefinition.size(); ++idx)
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			<< callableDataDefinition[idx] <<
			"void main()\n"
			"{\n"
			<< callableDataComputation[idx] <<
			"}\n";
		std::stringstream csname;
		csname << "call_" << idx;

		programCollection.glslSources.add(csname.str()) << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}

	// "call_call": nested callable that invokes the callable at SBT index 1 (payload location 1)
	// and forwards that result to its own caller.
	{
		std::stringstream css;
		css <<
			"#version 460 core\n"
			"#extension GL_EXT_ray_tracing : require\n"
			"layout(location = 0) callableDataInEXT uvec4 result;\n"
			"layout(location = 1) callableDataEXT uvec4 info;\n"
			"void main()\n"
			"{\n"
			"  executeCallableEXT(1, 1);\n"
			"  result = info;\n"
			"}\n";

		programCollection.glslSources.add("call_call") << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
	}
}
655
656TestInstance* CallableShaderTestCase::createInstance (Context& context) const
657{
658	return new CallableShaderTestInstance(context, m_data);
659}
660
661CallableShaderTestInstance::CallableShaderTestInstance (Context& context, const TestParams& data)
662	: vkt::TestInstance		(context)
663	, m_data				(data)
664{
665}
666
667CallableShaderTestInstance::~CallableShaderTestInstance (void)
668{
669}
670
// Builds the pipeline, shader binding tables, storage image and acceleration
// structures described by m_data.testConfiguration, records a single
// cmdTraceRays of m_data.width x m_data.height x 1, submits it and waits.
// Returns a host-visible buffer that holds a copy of the result image; the
// mapped range has already been invalidated so the host can read it directly.
de::MovePtr<BufferWithMemory> CallableShaderTestInstance::runTest ()
{
	const InstanceInterface&			vki									= m_context.getInstanceInterface();
	const DeviceInterface&				vkd									= m_context.getDeviceInterface();
	const VkDevice						device								= m_context.getDevice();
	const VkPhysicalDevice				physicalDevice						= m_context.getPhysicalDevice();
	const deUint32						queueFamilyIndex					= m_context.getUniversalQueueFamilyIndex();
	const VkQueue						queue								= m_context.getUniversalQueue();
	Allocator&							allocator							= m_context.getDefaultAllocator();
	const deUint32						pixelCount							= m_data.width * m_data.height * 1;

	// Set 0: binding 0 = storage image holding the result, binding 1 = top-level acceleration structure.
	const Move<VkDescriptorSetLayout>	descriptorSetLayout					= DescriptorSetLayoutBuilder()
																					.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
																					.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
																					.build(vkd, device);
	const Move<VkDescriptorPool>		descriptorPool						= DescriptorPoolBuilder()
																					.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
																					.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
																					.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
	const Move<VkDescriptorSet>			descriptorSet						= makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
	const Move<VkPipelineLayout>		pipelineLayout						= makePipelineLayout(vkd, device, descriptorSetLayout.get());

	// The test configuration decides which shaders go into the pipeline.
	de::MovePtr<RayTracingPipeline>		rayTracingPipeline					= de::newMovePtr<RayTracingPipeline>();
	m_data.testConfiguration->initRayTracingShaders(rayTracingPipeline, m_context, m_data);
	Move<VkPipeline>					pipeline							= rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);

	// SBT buffers and regions are filled in by the test configuration.
	de::MovePtr<BufferWithMemory>		raygenShaderBindingTable;
	de::MovePtr<BufferWithMemory>		hitShaderBindingTable;
	de::MovePtr<BufferWithMemory>		missShaderBindingTable;
	de::MovePtr<BufferWithMemory>		callableShaderBindingTable;
	VkStridedDeviceAddressRegionKHR		raygenShaderBindingTableRegion;
	VkStridedDeviceAddressRegionKHR		hitShaderBindingTableRegion;
	VkStridedDeviceAddressRegionKHR		missShaderBindingTableRegion;
	VkStridedDeviceAddressRegionKHR		callableShaderBindingTableRegion;
	m_data.testConfiguration->initShaderBindingTables(rayTracingPipeline, m_context, m_data, *pipeline, getShaderGroupHandleSize(vki, physicalDevice), getShaderGroupBaseAlignment(vki, physicalDevice), raygenShaderBindingTable, hitShaderBindingTable, missShaderBindingTable, callableShaderBindingTable, raygenShaderBindingTableRegion, hitShaderBindingTableRegion, missShaderBindingTableRegion, callableShaderBindingTableRegion);

	// Result image (format chosen by the configuration) and its view.
	const VkFormat						imageFormat							= m_data.testConfiguration->getResultImageFormat();
	const VkImageCreateInfo				imageCreateInfo						= makeImageCreateInfo(m_data.width, m_data.height, imageFormat);
	const VkImageSubresourceRange		imageSubresourceRange				= makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
	const de::MovePtr<ImageWithMemory>	image								= de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
	const Move<VkImageView>				imageView							= makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_2D, imageFormat, imageSubresourceRange);

	// Host-visible readback buffer sized for one texel per launch.
	const VkBufferCreateInfo			resultBufferCreateInfo				= makeBufferCreateInfo(pixelCount*m_data.testConfiguration->getResultImageFormatSize(), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
	const VkImageSubresourceLayers		resultBufferImageSubresourceLayers	= makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
	const VkBufferImageCopy				resultBufferImageRegion				= makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 1), resultBufferImageSubresourceLayers);
	de::MovePtr<BufferWithMemory>		resultBuffer						= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));

	const VkDescriptorImageInfo			descriptorImageInfo					= makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);

	const Move<VkCommandPool>			cmdPool								= createCommandPool(vkd, device, 0, queueFamilyIndex);
	const Move<VkCommandBuffer>			cmdBuffer							= allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

	// Declared outside the recording scope so the acceleration structures outlive the submission below.
	std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >	bottomLevelAccelerationStructures;
	de::MovePtr<TopLevelAccelerationStructure>						topLevelAccelerationStructure;

	beginCommandBuffer(vkd, *cmdBuffer, 0u);
	{
		// Clear the result image to the configuration's clear value.
		const VkImageMemoryBarrier			preImageBarrier						= makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
																					VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
																					**image, imageSubresourceRange);
		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);

		const VkClearValue					clearValue							= m_data.testConfiguration->getClearValue();
		vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);

		// Transition to GENERAL for storage-image access.
		// NOTE(review): the image is next accessed by the ray tracing shaders, but this barrier
		// targets the acceleration-structure build stage/access masks. Ordering relative to
		// cmdTraceRays seems to rely on barriers recorded inside createAndBuild(); confirm.
		const VkImageMemoryBarrier			postImageBarrier					= makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
																					VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
																					**image, imageSubresourceRange);
		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);

		// Build the scene: all BLASes first, then the TLAS that references them.
		bottomLevelAccelerationStructures										= m_data.testConfiguration->initBottomAccelerationStructures(m_context, m_data);
		for (auto& blas : bottomLevelAccelerationStructures)
			blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
		topLevelAccelerationStructure											= m_data.testConfiguration->initTopAccelerationStructure(m_context, m_data, bottomLevelAccelerationStructures);
		topLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);

		const TopLevelAccelerationStructure*			topLevelAccelerationStructurePtr		= topLevelAccelerationStructure.get();
		VkWriteDescriptorSetAccelerationStructureKHR	accelerationStructureWriteDescriptorSet	=
		{
			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,	//  VkStructureType						sType;
			DE_NULL,															//  const void*							pNext;
			1u,																	//  deUint32							accelerationStructureCount;
			topLevelAccelerationStructurePtr->getPtr(),							//  const VkAccelerationStructureKHR*	pAccelerationStructures;
		};

		DescriptorSetUpdateBuilder()
			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
			.update(vkd, device);

		vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);

		vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);

		cmdTraceRays(vkd,
			*cmdBuffer,
			&raygenShaderBindingTableRegion,
			&missShaderBindingTableRegion,
			&hitShaderBindingTableRegion,
			&callableShaderBindingTableRegion,
			m_data.width, m_data.height, 1);

		// Shader writes -> transfer read (image copy), then transfer write -> host read (readback).
		const VkMemoryBarrier							postTraceMemoryBarrier					= makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
		const VkMemoryBarrier							postCopyMemoryBarrier					= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);

		vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u, &resultBufferImageRegion);

		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
	}
	endCommandBuffer(vkd, *cmdBuffer);

	submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());

	// Make the device's writes to the mapped readback memory visible to the host.
	invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(), resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);

	return resultBuffer;
}
789
790tcu::TestStatus CallableShaderTestInstance::iterate (void)
791{
792	// run test using arrays of pointers
793	const de::MovePtr<BufferWithMemory>	buffer		= runTest();
794
795	if (!m_data.testConfiguration->verifyImage(buffer.get(), m_context, m_data))
796		return tcu::TestStatus::fail("Fail");
797	return tcu::TestStatus::pass("Pass");
798}
799
// Callable-data locations baked into the generated GLSL as
// CALLABLE_DATA_UINT_LOC, CALLABLE_DATA_FLOAT_LOC and CALLABLE_DATA_UINT_OUT_LOC.
constexpr deUint32 callableDataUintLoc = 0;
constexpr deUint32 callableDataFloatLoc = 1;
constexpr deUint32 callableDataUintOutLoc = 2;
803
// Host mirror of the `shaderRecordEXT` block read by the "build-callable-0"
// shader, which adds ((base << shift) + offset) * multiplier to its uint
// callable data. Member order/types must match that GLSL block.
struct CallableBuffer0
{
    deUint32 base;
    deUint32 shift;
    deUint32 offset;
    deUint32 multiplier;
};
811
// Host mirror of the `shaderRecordEXT` block read by the "build-callable-1"
// shader, which adds (numerator / denomenator)^power to its float callable data.
// NOTE(review): the GLSL block also declares a trailing `uint reserved;` that this
// struct lacks (12 vs 16 bytes); presumably harmless because the shader never reads
// it, but confirm how the shader record is sized when it is written to the SBT.
struct CallableBuffer1
{
	float numerator;
	float denomenator;
	deUint32 power;
};
818
819struct Ray
820{
821	Ray() : o(0.0f), tmin(0.0f), d(0.0f), tmax(0.0f){}
822	Ray(const tcu::Vec3& io, float imin, const tcu::Vec3& id, float imax): o(io), tmin(imin), d(id), tmax(imax){}
823	tcu::Vec3 o;
824	float tmin;
825	tcu::Vec3 d;
826	float tmax;
827};
828
// Initial payload.closestT written by the raygen shader before tracing; the miss
// shader leaves it untouched, so verifyResultData() expects it for missed rays.
constexpr float MAX_T_VALUE = 1000.0f;
830
831void AddVertexLayers(std::vector<tcu::Vec3>* pVerts, deUint32 newLayers)
832{
833	size_t vertsPerLayer = pVerts->size();
834	size_t totalLayers = 1 + newLayers;
835
836	pVerts->reserve(pVerts->size() * totalLayers);
837
838	for (size_t layer = 0; layer < newLayers; ++layer)
839	{
840		for (size_t vert = 0; vert < vertsPerLayer; ++vert)
841		{
842			bool flippedLayer = (layer % 2) == 0;
843			tcu::Vec3 stage = (*pVerts)[flippedLayer ? (vertsPerLayer - vert - 1) : vert];
844			++stage.z();
845
846			pVerts->push_back(stage);
847		}
848	}
849}
850
// Returns true when `actual` is within an absolute tolerance of 0.01 of
// `expected`. A NaN on either side never matches.
bool compareFloat(float actual, float expected)
{
	constexpr float eps = 0.01f;

	// Deliberately avoids unqualified abs(): depending on which headers are
	// visible it may bind to the C `int abs(int)` overload, truncating the
	// difference and accepting errors up to 1.0. The two-sided range check
	// also rejects NaN, since every ordered comparison with NaN is false.
	const float diff = expected - actual;
	return (diff >= -eps) && (diff <= eps);
}
863
864class InvokeCallableShaderTestCase : public TestCase
865{
866	public:
867							InvokeCallableShaderTestCase			(tcu::TestContext& context, const char* name, const TestParams& data);
868							~InvokeCallableShaderTestCase			(void);
869
870	virtual void			checkSupport								(Context& context) const;
871	virtual	void			initPrograms								(SourceCollections& programCollection) const;
872	virtual TestInstance*	createInstance								(Context& context) const;
873private:
874	TestParams			params;
875};
876
877class InvokeCallableShaderTestInstance : public TestInstance
878{
879public:
880																	InvokeCallableShaderTestInstance	(Context& context, const TestParams& data);
881																	~InvokeCallableShaderTestInstance	(void);
882	tcu::TestStatus													iterate									(void);
883
884private:
885	TestParams													params;
886};
887
888InvokeCallableShaderTestCase::InvokeCallableShaderTestCase (tcu::TestContext& context, const char* name, const TestParams& data)
889	: vkt::TestCase	(context, name)
890	, params		(data)
891{
892}
893
894InvokeCallableShaderTestCase::~InvokeCallableShaderTestCase (void)
895{
896}
897
898void InvokeCallableShaderTestCase::checkSupport (Context& context) const
899{
900	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
901	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
902
903	const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&	rayTracingPipelineFeaturesKHR		= context.getRayTracingPipelineFeatures();
904	if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE )
905		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
906
907	const VkPhysicalDeviceAccelerationStructureFeaturesKHR&	accelerationStructureFeaturesKHR	= context.getAccelerationStructureFeatures();
908	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
909		TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
910}
911
912//resultData:
913// x - value0
914// y - value1
915// z - value2
916// w - closestT
917bool verifyResultData(const tcu::Vec4* resultData, deUint32 index, bool hit, const TestParams& params)
918{
919	bool success = true;
920
921	float refValue0 = 0.0f;
922	float refValue1 = 0.0f;
923	float refValue2 = 0.0f;
924
925	if (hit)
926	{
927		switch (params.invokingShader)
928		{
929		case glu::SHADERTYPE_RAYGEN :
930		case glu::SHADERTYPE_CLOSEST_HIT:
931		case glu::SHADERTYPE_CALLABLE:
932			refValue0 = 133.0f;
933			break;
934		case glu::SHADERTYPE_MISS:
935			break;
936		default:
937			TCU_THROW(InternalError, "Wrong shader invoking type");
938			break;
939		}
940
941		if (params.multipleInvocations)
942		{
943			switch (params.invokingShader)
944			{
945			case glu::SHADERTYPE_RAYGEN:
946				refValue1 = 17.64f;
947				refValue2 = 35.28f;
948				break;
949			case glu::SHADERTYPE_CLOSEST_HIT:
950				refValue1 = 17.64f;
951				refValue2 = index < 4 ? 35.28f : 8.82f;
952				break;
953			case glu::SHADERTYPE_CALLABLE:
954				refValue1 = 17.64f;
955				refValue2 = index < 6 ? 35.28f : 8.82f;
956				break;
957			case glu::SHADERTYPE_MISS:
958				break;
959			default:
960				TCU_THROW(InternalError, "Wrong shader invoking type");
961				break;
962			}
963		}
964
965		if (resultData->w() != 2.0f)
966		{
967			success = false;
968		}
969	}
970
971	if (!hit)
972	{
973		switch (params.invokingShader)
974		{
975		case glu::SHADERTYPE_RAYGEN:
976		case glu::SHADERTYPE_MISS:
977		case glu::SHADERTYPE_CALLABLE:
978			refValue0 = 133.0f;
979			break;
980		case glu::SHADERTYPE_CLOSEST_HIT:
981			break;
982		default:
983			TCU_THROW(InternalError, "Wrong shader invoking type");
984			break;
985		}
986
987		if (params.multipleInvocations)
988		{
989			switch (params.invokingShader)
990			{
991			case glu::SHADERTYPE_RAYGEN:
992				refValue1 = 17.64f;
993				refValue2 = 8.82f;
994				break;
995			case glu::SHADERTYPE_MISS:
996				refValue1 = 17.64f;
997				refValue2 = index < 10 ? 35.28f : 8.82f;
998				break;
999			case glu::SHADERTYPE_CALLABLE:
1000				refValue1 = 17.64f;
1001				refValue2 = index < 6 ? 35.28f : 8.82f;
1002				break;
1003			case glu::SHADERTYPE_CLOSEST_HIT:
1004				break;
1005			default:
1006				TCU_THROW(InternalError, "Wrong shader invoking type");
1007				break;
1008			}
1009		}
1010
1011		if (resultData->w() != MAX_T_VALUE)
1012		{
1013			success = false;
1014		}
1015	}
1016
1017	if ((!compareFloat(resultData->x(), refValue0)) ||
1018		(!compareFloat(resultData->y(), refValue1)) ||
1019		(!compareFloat(resultData->z(), refValue2)))
1020	{
1021		success = false;
1022	}
1023
1024	return success;
1025}
1026
1027std::string getRayGenSource(bool invokeCallable, bool multiInvoke)
1028{
1029	std::ostringstream src;
1030	src <<
1031		"struct Payload { uint lastShader; float closestT; };\n"
1032		"layout(location = 0) rayPayloadEXT Payload payload;\n";
1033
1034	if (invokeCallable)
1035	{
1036		src <<
1037			"#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc << "\n"
1038			"layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1039
1040		if (multiInvoke)
1041		{
1042			src <<
1043				"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1044				"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1045		}
1046	}
1047
1048	src <<
1049		"void main() {\n"
1050		"   uint index = launchIndex();\n"
1051		"   Ray ray = rays[index];\n"
1052	    "   results[index].value0 = 0;\n"
1053		"   results[index].value1 = 0;\n"
1054		"   results[index].value2 = 0;\n";
1055
1056	if (invokeCallable)
1057	{
1058		src <<
1059			"   callableDataUint = " << "0" << ";\n"
1060			"   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1061			"   results[index].value0 = float(callableDataUint);\n";
1062
1063		if (multiInvoke)
1064		{
1065			src <<
1066				"   callableDataFloat = 0.0;\n"
1067				"   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1068				"   results[index].value1 = callableDataFloat;\n";
1069		}
1070	}
1071
1072	src <<
1073		"   payload.lastShader = " << glu::SHADERTYPE_RAYGEN << ";\n"
1074		"   payload.closestT = " << MAX_T_VALUE << ";\n"
1075		"   traceRayEXT(scene, 0x0, 0xff, 0, 0, 0, ray.pos, " << "ray.tmin" << ", ray.dir, ray.tmax, 0);\n";
1076
1077	if (invokeCallable && multiInvoke)
1078	{
1079		src <<
1080			"   executeCallableEXT(payload.lastShader == " << glu::SHADERTYPE_CLOSEST_HIT << " ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1081			"   results[index].value2 = callableDataFloat;\n";
1082	}
1083
1084	src <<
1085		"   results[index].closestT = payload.closestT;\n"
1086		"}";
1087
1088	return src.str();
1089}
1090
1091std::string getClosestHitSource(bool invokeCallable, bool multiInvoke)
1092{
1093	std::ostringstream src;
1094	src <<
1095		"struct Payload { uint lastShader; float closestT; };\n"
1096		"layout(location = 0) rayPayloadInEXT Payload payload;\n";
1097
1098	if (invokeCallable)
1099	{
1100		src <<
1101			"#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc << "\n"
1102			"layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1103
1104		if (multiInvoke)
1105		{
1106			src <<
1107				"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1108				"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1109		}
1110	}
1111
1112	src <<
1113		"void main() {\n"
1114		"   payload.lastShader = " << glu::SHADERTYPE_CLOSEST_HIT << ";\n"
1115		"   payload.closestT = gl_HitTEXT;\n";
1116
1117	if (invokeCallable)
1118	{
1119		src <<
1120			"   uint index = launchIndex();\n"
1121			"   callableDataUint = 0;\n"
1122			"   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1123			"   results[index].value0 = float(callableDataUint);\n";
1124
1125		if (multiInvoke)
1126		{
1127			src <<
1128				"   callableDataFloat = 0.0;\n"
1129				"   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1130				"   results[index].value1 = callableDataFloat;\n"
1131				"   executeCallableEXT(index < 4 ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1132				"   results[index].value2 = callableDataFloat;\n";
1133		}
1134	}
1135
1136	src <<
1137		"}";
1138
1139	return src.str();
1140}
1141
1142std::string getMissSource(bool invokeCallable, bool multiInvoke)
1143{
1144	std::ostringstream src;
1145	src <<
1146		"struct Payload { uint lastShader; float closestT; };\n"
1147		"layout(location = 0) rayPayloadInEXT Payload payload;\n";
1148
1149	if (invokeCallable)
1150	{
1151		src <<
1152			"#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc << "\n"
1153			"layout(location = CALLABLE_DATA_UINT_LOC) callableDataEXT uint callableDataUint;\n";
1154
1155		if (multiInvoke)
1156		{
1157			src <<
1158				"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1159				"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1160		}
1161	}
1162
1163	src <<
1164		"void main() {\n"
1165		"   payload.lastShader = " << glu::SHADERTYPE_MISS << ";\n";
1166
1167	if (invokeCallable)
1168	{
1169		src <<
1170			"   uint index = launchIndex();\n"
1171			"   callableDataUint = 0;\n"
1172			"   executeCallableEXT(0, CALLABLE_DATA_UINT_LOC);\n"
1173			"   results[index].value0 = float(callableDataUint);\n";
1174
1175		if (multiInvoke)
1176		{
1177			src <<
1178				"   callableDataFloat = 0.0;\n"
1179				"   executeCallableEXT(1, CALLABLE_DATA_FLOAT_LOC);\n"
1180				"   results[index].value1 = callableDataFloat;\n"
1181				"   executeCallableEXT(index < 10 ? 1 : 2, CALLABLE_DATA_FLOAT_LOC);\n"
1182				"   results[index].value2 = callableDataFloat;\n";
1183		}
1184	}
1185
1186	src <<
1187		"}";
1188
1189	return src.str();
1190}
1191
1192std::string getCallableSource(bool invokeCallable, bool multiInvoke)
1193{
1194	std::ostringstream src;
1195	src <<
1196		"#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc << "\n"
1197		"layout(location = CALLABLE_DATA_UINT_LOC) callableDataInEXT uint callableDataUintIn;\n";
1198
1199	if (invokeCallable)
1200	{
1201		src << "#define CALLABLE_DATA_UINT_OUT_LOC " << callableDataUintOutLoc << "\n"
1202			<< "layout(location = CALLABLE_DATA_UINT_OUT_LOC) callableDataEXT uint callableDataUint;\n";
1203
1204		if (multiInvoke)
1205		{
1206			src <<
1207				"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1208				"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataEXT float callableDataFloat;\n";
1209		}
1210	}
1211
1212	src <<
1213		"void main() {\n";
1214
1215	if (invokeCallable)
1216	{
1217		src <<
1218			"   uint index = launchIndex();\n"
1219			"   callableDataUint = 0;\n"
1220			"   executeCallableEXT(1, CALLABLE_DATA_UINT_OUT_LOC);\n"
1221			"   callableDataUintIn = callableDataUint;\n";
1222
1223		if (multiInvoke)
1224		{
1225			src <<
1226				"   callableDataFloat = 0.0;\n"
1227				"   executeCallableEXT(2, CALLABLE_DATA_FLOAT_LOC);\n"
1228				"   results[index].value1 = callableDataFloat;\n"
1229				"   executeCallableEXT(index < 6 ? 2 : 3, CALLABLE_DATA_FLOAT_LOC);\n"
1230				"   results[index].value2 = callableDataFloat;\n";
1231		}
1232	}
1233
1234	src <<
1235		"}";
1236
1237	return src.str();
1238}
1239
// Set-0 binding numbers used by the declarations generateShaderSource() emits.
constexpr deUint32 DefaultResultBinding = 0;
constexpr deUint32 DefaultSceneBinding = 1;
constexpr deUint32 DefaultRaysBinding = 2;

// Bit flags selecting which declarations generateShaderSource() emits.
enum ShaderSourceFlag
{
	DEFINE_RAY = 0x1,				// struct Ray
	DEFINE_RESULT_BUFFER = 0x2,		// buffer Results { <resultType> results[]; } at DefaultResultBinding
	DEFINE_SCENE = 0x4,				// accelerationStructureEXT scene at DefaultSceneBinding
	DEFINE_RAY_BUFFER = 0x8,		// buffer Rays { Ray rays[]; } at DefaultRaysBinding (implies DEFINE_RAY)
	DEFINE_SIMPLE_BINDINGS = DEFINE_RESULT_BUFFER | DEFINE_SCENE | DEFINE_RAY_BUFFER	// results, scene and rays together
};
1252
1253static inline std::string generateShaderSource(const char* body, const char* resultType = "", deUint32 flags = 0, const char* prefix = "")
1254{
1255	std::ostringstream src;
1256	src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n";
1257
1258	src << "#extension GL_EXT_ray_tracing : enable\n";
1259
1260	src << prefix << "\n";
1261
1262	if (flags & DEFINE_SIMPLE_BINDINGS)
1263		flags |= DEFINE_RAY_BUFFER;
1264
1265	if (flags & DEFINE_RAY_BUFFER)
1266		flags |= DEFINE_RAY;
1267
1268	if (flags & DEFINE_RAY)
1269	{
1270		src << "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n";
1271	}
1272
1273	if (flags & DEFINE_RESULT_BUFFER)
1274		src << "layout(std430, set = 0, binding = " << DefaultResultBinding << ") buffer Results { " << resultType << " results[]; };\n";
1275
1276	if (flags & DEFINE_SCENE)
1277	{
1278		src << "layout(set = 0, binding = " << DefaultSceneBinding << ") uniform accelerationStructureEXT scene;\n";
1279	}
1280
1281	if (flags & DEFINE_RAY_BUFFER)
1282		src << "layout(std430, set = 0, binding = " << DefaultRaysBinding << ") buffer Rays { Ray rays[]; };\n";
1283
1284	src << "uint launchIndex() { return gl_LaunchIDEXT.z*gl_LaunchSizeEXT.x*gl_LaunchSizeEXT.y + gl_LaunchIDEXT.y*gl_LaunchSizeEXT.x + gl_LaunchIDEXT.x; }\n";
1285
1286	src << body;
1287
1288	return src.str();
1289}
1290
1291template<typename T> inline void addShaderSource(SourceCollections& programCollection, const char* identifier,
1292											const char* body, const char* resultType = "", deUint32 flags = 0,
1293                                            const char* prefix = "", deUint32 validatorOptions = 0U)
1294{
1295	std::string text = generateShaderSource(body, resultType, flags, prefix);
1296
1297	const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, validatorOptions, true);
1298	programCollection.glslSources.add(identifier) << T(text) << buildOptions;
1299}
1300
1301
1302
1303void InvokeCallableShaderTestCase::initPrograms (SourceCollections& programCollection) const
1304{
1305	addShaderSource<glu::RaygenSource>(programCollection, "build-raygen",
1306		getRayGenSource(false, false).c_str(), "Result", DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1307		"struct Result { float value0; float value1; float value2; float closestT;};");
1308
1309	addShaderSource<glu::RaygenSource>(programCollection, "build-raygen-invoke-callable",
1310		getRayGenSource(true, false).c_str(), "Result", DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1311		"struct Result { float value0; float value1; float value2; float closestT;};");
1312
1313	addShaderSource<glu::ClosestHitSource>(programCollection, "build-closesthit",
1314		getClosestHitSource(false, false).c_str(), "", 0, "");
1315
1316	addShaderSource<glu::MissSource>(programCollection, "build-miss",
1317		getMissSource(false, false).c_str(), "", 0, "");
1318
1319	const std::string RAY_PAYLOAD		    = "rayPayloadEXT";
1320	const std::string TRACE_RAY			    = "traceRayEXT";
1321	const std::string RAY_PAYLOAD_IN	    = "rayPayloadInEXT";
1322	const std::string HIT_ATTRIBUTE		    = "hitAttributeEXT";
1323    const std::string REPORT_INTERSECTION   = "reportIntersectionEXT";
1324
1325    const std::string SHADER_RECORD         = "shaderRecordEXT";
1326	const std::string CALLABLE_DATA_IN	    = "callableDataInEXT";
1327	const std::string CALLABLE_DATA		    = "callableDataEXT";
1328	const std::string EXECUTE_CALLABE	    = "executeCallableEXT";
1329
1330	std::ostringstream src;
1331    src <<
1332		"#define CALLABLE_DATA_UINT_LOC " << callableDataUintLoc << "\n"
1333		"layout(location = CALLABLE_DATA_UINT_LOC) callableDataInEXT uint callableDataUint;\n"
1334        "layout(" << SHADER_RECORD << ") buffer callableBuffer\n"
1335        "{\n"
1336        "   uint base;\n"
1337        "   uint shift;\n"
1338        "   uint offset;\n"
1339        "   uint multiplier;\n"
1340        "};\n"
1341        "void main() {\n"
1342        "   callableDataUint += ((base << shift) + offset) * multiplier;\n"
1343        "}";
1344
1345    addShaderSource<glu::CallableSource>(programCollection, "build-callable-0", src.str().c_str(),
1346        "", 0, "");
1347
1348	if (params.multipleInvocations)
1349	{
1350		switch (params.invokingShader)
1351		{
1352		case glu::SHADERTYPE_RAYGEN:
1353			addShaderSource<glu::RaygenSource>(programCollection, "build-raygen-invoke-callable-multi",
1354				getRayGenSource(true, true).c_str(), "Result", DEFINE_RAY_BUFFER | DEFINE_SIMPLE_BINDINGS,
1355				"struct Result { float value0; float value1; float value2; float closestT;};");
1356
1357			break;
1358		case glu::SHADERTYPE_CLOSEST_HIT:
1359			addShaderSource<glu::ClosestHitSource>(programCollection, "build-closesthit-invoke-callable-multi",
1360				getClosestHitSource(true, true).c_str(), "Result", DEFINE_RESULT_BUFFER,
1361				"struct Result { float value0; float value1; float value2; float closestT;};");
1362
1363			break;
1364		case glu::SHADERTYPE_MISS:
1365			addShaderSource<glu::MissSource>(programCollection, "build-miss-invoke-callable-multi",
1366				getMissSource(true, true).c_str(), "Result", DEFINE_RESULT_BUFFER,
1367				"struct Result { float value0; float value1; float value2; float closestT;};");
1368
1369			break;
1370		case glu::SHADERTYPE_CALLABLE:
1371			addShaderSource<glu::CallableSource>(programCollection, "build-callable-invoke-callable-multi",
1372				getCallableSource(true, true).c_str(), "Result", DEFINE_RESULT_BUFFER,
1373				"struct Result { float value0; float value1; float value2; float closestT;};");
1374
1375			break;
1376		default:
1377			TCU_THROW(InternalError, "Wrong shader invoking type");
1378			break;
1379		}
1380
1381		src.str(std::string());
1382		src <<
1383			"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1384			"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataInEXT float callableDataFloat;\n"
1385			"layout(" << SHADER_RECORD << ") buffer callableBuffer\n"
1386			"{\n"
1387			"   float numerator;\n"
1388			"   float denomenator;\n"
1389			"   uint power;\n"
1390			"   uint reserved;\n"
1391			"};\n"
1392			"void main() {\n"
1393			"   float base = numerator / denomenator;\n"
1394			"   float result = 1;\n"
1395			"   for (uint i = 0; i < power; ++i)\n"
1396			"   {\n"
1397			"      result *= base;\n"
1398			"   }\n"
1399			"   callableDataFloat += result;\n"
1400			"}";
1401
1402		addShaderSource<glu::CallableSource>(programCollection, "build-callable-1", src.str().c_str(),
1403            "", 0, "");
1404
1405		src.str(std::string());
1406		src <<
1407			"#define CALLABLE_DATA_FLOAT_LOC " << callableDataFloatLoc << "\n"
1408			"layout(location = CALLABLE_DATA_FLOAT_LOC) callableDataInEXT float callableDataFloat;\n"
1409			"void main() {\n"
1410			"   callableDataFloat /= 2.0f;\n"
1411			"}";
1412
1413		addShaderSource<glu::CallableSource>(programCollection, "build-callable-2", src.str().c_str(),
1414            "", 0, "");
1415	}
1416	else
1417	{
1418		switch (params.invokingShader)
1419		{
1420		case glu::SHADERTYPE_RAYGEN:
1421			// Always defined since it's needed to invoke callable shaders that invoke other callable shaders
1422
1423			break;
1424		case glu::SHADERTYPE_CLOSEST_HIT:
1425			addShaderSource<glu::ClosestHitSource>(programCollection, "build-closesthit-invoke-callable",
1426				getClosestHitSource(true, false).c_str(), "Result", DEFINE_RESULT_BUFFER,
1427				"struct Result { float value0; float value1; float value2; float closestT;};");
1428
1429			break;
1430		case glu::SHADERTYPE_MISS:
1431			addShaderSource<glu::MissSource>(programCollection, "build-miss-invoke-callable",
1432				getMissSource(true, false).c_str(), "Result", DEFINE_RESULT_BUFFER,
1433				"struct Result { float value0; float value1; float value2; float closestT;};");
1434
1435			break;
1436		case glu::SHADERTYPE_CALLABLE:
1437			addShaderSource<glu::CallableSource>(programCollection, "build-callable-invoke-callable",
1438				getCallableSource(true, false).c_str(), "Result", DEFINE_RESULT_BUFFER,
1439				"struct Result { float value0; float value1; float value2; float closestT;};");
1440
1441			break;
1442		default:
1443			TCU_THROW(InternalError, "Wrong shader invoking type");
1444			break;
1445		}
1446	}
1447}
1448
1449TestInstance* InvokeCallableShaderTestCase::createInstance (Context& context) const
1450{
1451	return new InvokeCallableShaderTestInstance(context, params);
1452}
1453
1454InvokeCallableShaderTestInstance::InvokeCallableShaderTestInstance (Context& context, const TestParams& data)
1455	: vkt::TestInstance		(context)
1456	, params				(data)
1457{
1458}
1459
1460InvokeCallableShaderTestInstance::~InvokeCallableShaderTestInstance (void)
1461{
1462}
1463
1464tcu::TestStatus InvokeCallableShaderTestInstance::iterate()
1465{
1466	const VkDevice device = m_context.getDevice();
1467	const DeviceInterface& vk = m_context.getDeviceInterface();
1468	const InstanceInterface& vki = m_context.getInstanceInterface();
1469	Allocator& allocator = m_context.getDefaultAllocator();
1470	de::MovePtr<RayTracingProperties> rayTracingProperties = makeRayTracingProperties(vki, m_context.getPhysicalDevice());
1471
1472	vk::Move<VkDescriptorPool>		descriptorPool;
1473	vk::Move<VkDescriptorSetLayout> descriptorSetLayout;
1474	std::vector<vk::Move<VkDescriptorSet>>		descriptorSet;
1475	vk::Move<VkPipelineLayout>		pipelineLayout;
1476
1477	vk::DescriptorPoolBuilder descriptorPoolBuilder;
1478
1479	deUint32 storageBufCount = 0;
1480
1481	const VkDescriptorType accelType = VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR;
1482
1483	storageBufCount += 1;
1484
1485	storageBufCount += 1;
1486
1487	descriptorPoolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, storageBufCount);
1488
1489	descriptorPoolBuilder.addType(accelType, 1);
1490
1491	descriptorPool = descriptorPoolBuilder.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1);
1492
1493	vk::DescriptorSetLayoutBuilder setLayoutBuilder;
1494
1495	const deUint32 AllStages = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
1496							   VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
1497							   VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1498
1499	setLayoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, AllStages);
1500	setLayoutBuilder.addSingleBinding(accelType, AllStages);
1501	setLayoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, AllStages);
1502
1503	descriptorSetLayout = setLayoutBuilder.build(vk, device);
1504
1505	const VkDescriptorSetAllocateInfo descriptorSetAllocateInfo =
1506	{
1507		VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,		// VkStructureType				sType;
1508		DE_NULL,											// const void*					pNext;
1509		*descriptorPool,									// VkDescriptorPool				descriptorPool;
1510		1u,													// deUint32						setLayoutCount;
1511		&descriptorSetLayout.get()							// const VkDescriptorSetLayout*	pSetLayouts;
1512	};
1513
1514	descriptorSet.push_back(allocateDescriptorSet(vk, device, &descriptorSetAllocateInfo));
1515
1516	const VkPipelineLayoutCreateInfo pipelineLayoutInfo =
1517	{
1518		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,		// VkStructureType				sType;
1519		DE_NULL,											// const void*					pNext;
1520		(VkPipelineLayoutCreateFlags)0,						// VkPipelineLayoutCreateFlags	flags;
1521		1u,													// deUint32						setLayoutCount;
1522		&descriptorSetLayout.get(),							// const VkDescriptorSetLayout*	pSetLayouts;
1523		0u,													// deUint32						pushConstantRangeCount;
1524		nullptr,											// const VkPushConstantRange*	pPushConstantRanges;
1525	};
1526
1527	pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutInfo);
1528
1529	std::string raygenId = "build-raygen";
1530	std::string missId = "build-miss";
1531	std::string closestHitId = "build-closesthit";
1532	std::vector<std::string> callableIds;
1533
1534	switch (params.invokingShader)
1535	{
1536	case glu::SHADERTYPE_RAYGEN:
1537	{
1538		raygenId.append("-invoke-callable");
1539
1540		if (params.multipleInvocations)
1541		{
1542			raygenId.append("-multi");
1543		}
1544		break;
1545	}
1546	case glu::SHADERTYPE_MISS:
1547	{
1548		missId.append("-invoke-callable");
1549
1550		if (params.multipleInvocations)
1551		{
1552			missId.append("-multi");
1553		}
1554		break;
1555	}
1556	case glu::SHADERTYPE_CLOSEST_HIT:
1557	{
1558		closestHitId.append("-invoke-callable");
1559
1560		if (params.multipleInvocations)
1561		{
1562			closestHitId.append("-multi");
1563		}
1564		break;
1565	}
1566	case glu::SHADERTYPE_CALLABLE:
1567	{
1568		raygenId.append("-invoke-callable");
1569		std::string callableId("build-callable-invoke-callable");
1570
1571		if (params.multipleInvocations)
1572		{
1573			callableId.append("-multi");
1574		}
1575
1576		callableIds.push_back(callableId);
1577		break;
1578	}
1579	default:
1580	{
1581		TCU_THROW(InternalError, "Wrong shader invoking type");
1582		break;
1583	}
1584	}
1585
1586	callableIds.push_back("build-callable-0");
1587	if (params.multipleInvocations)
1588	{
1589		callableIds.push_back("build-callable-1");
1590		callableIds.push_back("build-callable-2");
1591	}
1592
1593	de::MovePtr<RayTracingPipeline>	rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
1594	rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vk, device, m_context.getBinaryCollection().get(raygenId.c_str()), 0), 0);
1595	rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vk, device, m_context.getBinaryCollection().get(missId.c_str()), 0), 1);
1596	rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vk, device, m_context.getBinaryCollection().get(closestHitId.c_str()), 0), 2);
1597	deUint32 callableGroup = 3;
1598	for (auto& callableId : callableIds)
1599	{
1600		rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,	createShaderModule(vk, device, m_context.getBinaryCollection().get(callableId.c_str()), 0), callableGroup);
1601		++callableGroup;
1602	}
1603	Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vk, device, *pipelineLayout);
1604
1605	CallableBuffer0 callableBuffer0 = { 1, 4, 3, 7 };
1606	CallableBuffer1 callableBuffer1 = { 10.5, 2.5, 2 };
1607
1608	size_t MaxBufferSize = std::max(sizeof(callableBuffer0), sizeof(callableBuffer1));
1609	deUint32 shaderGroupHandleSize = rayTracingProperties->getShaderGroupHandleSize();
1610	deUint32 shaderGroupBaseAlignment = rayTracingProperties->getShaderGroupBaseAlignment();
1611	size_t shaderStride = deAlign32(shaderGroupHandleSize + (deUint32)MaxBufferSize, shaderGroupHandleSize);
1612
1613	de::MovePtr<BufferWithMemory> raygenShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
1614	de::MovePtr<BufferWithMemory> missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
1615	de::MovePtr<BufferWithMemory> hitShaderBindingTable					= rayTracingPipeline->createShaderBindingTable(vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
1616	de::MovePtr<BufferWithMemory> callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vk, device, *pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, (deUint32)callableIds.size(),
1617																													   0U, 0U, MemoryRequirement::Any, 0U, 0U, (deUint32)MaxBufferSize, nullptr, true);
1618
1619	VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1620	VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1621	VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1622	VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vk, device, callableShaderBindingTable->get(), 0), shaderStride, shaderGroupHandleSize);
1623
1624	size_t callableCount = 0;
1625
1626	if (params.invokingShader == glu::SHADERTYPE_CALLABLE)
1627	{
1628		callableCount++;
1629	}
1630
1631	deMemcpy((deUint8*)callableShaderBindingTable->getAllocation().getHostPtr() +
1632		(shaderStride * (callableCount)) + shaderGroupHandleSize,
1633		&callableBuffer0,
1634		sizeof(CallableBuffer0));
1635	callableCount++;
1636
1637	if (params.multipleInvocations)
1638	{
1639		deMemcpy((deUint8*)callableShaderBindingTable->getAllocation().getHostPtr() +
1640			(shaderStride * (callableCount)) + shaderGroupHandleSize,
1641			&callableBuffer1,
1642			sizeof(CallableBuffer1));
1643		callableCount++;
1644	}
1645
1646	flushMappedMemoryRange(vk, device, callableShaderBindingTable->getAllocation().getMemory(), callableShaderBindingTable->getAllocation().getOffset(), VK_WHOLE_SIZE);
1647
1648	//                 {I}
1649	// (-2,1) (-1,1)  (0,1)  (1,1)  (2,1)
1650	//    X------X------X------X------X
1651	//    |\     |\     |\     |\     |
1652	//    | \ {B}| \ {D}| \ {F}| \ {H}|
1653	// {K}|  \   |  \   |  \   |  \   |{L}
1654	//    |   \  |   \  |   \  |   \  |
1655	//    |{A} \ |{C} \ |{E} \ |{G} \ |
1656	//    |     \|     \|     \|     \|
1657	//    X------X------X------X------X
1658	// (-2,-1)(-1,-1) (0,-1) (1,-1) (2,-1)
1659	//                 {J}
1660	//
1661	// A, B, E, and F are initially opaque
1662	// A and C are forced opaque
1663	// E and G are forced non-opaque
1664
1665	std::vector<Ray> rays = {
1666		Ray{ tcu::Vec3(-1.67f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {A}
1667		Ray{ tcu::Vec3(-1.33f,  0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {B}
1668		Ray{ tcu::Vec3(-0.67f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {C}
1669		Ray{ tcu::Vec3(-0.33f,  0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {D}
1670		Ray{ tcu::Vec3( 0.33f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {E}
1671		Ray{ tcu::Vec3( 0.67f,  0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {F}
1672		Ray{ tcu::Vec3( 1.33f, -0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {G}
1673		Ray{ tcu::Vec3( 1.67f,  0.33f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {H}
1674		Ray{ tcu::Vec3( 0.0f,   1.01f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {I}
1675		Ray{ tcu::Vec3( 0.0f,  -1.01f, 0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {J}
1676		Ray{ tcu::Vec3(-2.01f,  0.0f,  0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }, // {K}
1677		Ray{ tcu::Vec3( 2.01f,  0.0f,  0.0f), 0.0f, tcu::Vec3(0.0f, 0.0f, 1.0f), MAX_T_VALUE }  // {L}
1678	};
1679
1680	// B & F
1681	std::vector<tcu::Vec3> blas0VertsOpaque = {
1682		{ -2.0f,  1.0f, 2.0f },
1683		{ -1.0f, -1.0f, 2.0f },
1684		{ -1.0f,  1.0f, 2.0f },
1685		{  0.0f,  1.0f, 2.0f },
1686		{  1.0f, -1.0f, 2.0f },
1687		{  1.0f,  1.0f, 2.0f }
1688	};
1689
1690	// D & H
1691	std::vector<tcu::Vec3> blas0VertsNoOpaque = {
1692		{ -1.0f,  1.0f, 2.0f },
1693		{  0.0f, -1.0f, 2.0f },
1694		{  0.0f,  1.0f, 2.0f },
1695		{  1.0f,  1.0f, 2.0f },
1696		{  2.0f, -1.0f, 2.0f },
1697		{  2.0f,  1.0f, 2.0f }
1698	};
1699
1700	// A
1701	std::vector<tcu::Vec3> blas1VertsOpaque = {
1702		{ -2.0f,  1.0f, 2.0f },
1703		{ -2.0f, -1.0f, 2.0f },
1704		{ -1.0f, -1.0f, 2.0f }
1705	};
1706
1707	// C
1708	std::vector<tcu::Vec3> blas1VertsNoOpaque = {
1709		{ -1.0f,  1.0f, 2.0f },
1710		{ -1.0f, -1.0f, 2.0f },
1711		{  0.0f, -1.0f, 2.0f }
1712	};
1713
1714	// E
1715	std::vector<tcu::Vec3> blas2VertsOpaque = {
1716		{  0.0f,  1.0f, 2.0f },
1717		{  0.0f, -1.0f, 2.0f },
1718		{  1.0f, -1.0f, 2.0f }
1719	};
1720
1721	// G
1722	std::vector<tcu::Vec3> blas2VertsNoOpaque = {
1723		{  1.0f,  1.0f, 2.0f },
1724		{  1.0f, -1.0f, 2.0f },
1725		{  2.0f, -1.0f, 2.0f }
1726	};
1727
1728	AddVertexLayers(&blas0VertsOpaque, 1);
1729	AddVertexLayers(&blas0VertsNoOpaque, 1);
1730	AddVertexLayers(&blas1VertsOpaque, 1);
1731	AddVertexLayers(&blas1VertsNoOpaque, 1);
1732	AddVertexLayers(&blas2VertsOpaque, 1);
1733	AddVertexLayers(&blas2VertsNoOpaque, 1);
1734
1735	std::vector<tcu::Vec3> verts;
1736	verts.reserve(
1737		blas0VertsOpaque.size() + blas0VertsNoOpaque.size() +
1738		blas1VertsOpaque.size() + blas1VertsNoOpaque.size() +
1739		blas2VertsOpaque.size() + blas2VertsNoOpaque.size());
1740	verts.insert(verts.end(), blas0VertsOpaque.begin(), blas0VertsOpaque.end());
1741	verts.insert(verts.end(), blas0VertsNoOpaque.begin(), blas0VertsNoOpaque.end());
1742	verts.insert(verts.end(), blas1VertsOpaque.begin(), blas1VertsOpaque.end());
1743	verts.insert(verts.end(), blas1VertsNoOpaque.begin(), blas1VertsNoOpaque.end());
1744	verts.insert(verts.end(), blas2VertsOpaque.begin(), blas2VertsOpaque.end());
1745	verts.insert(verts.end(), blas2VertsNoOpaque.begin(), blas2VertsNoOpaque.end());
1746
1747	tcu::Surface resultImage(static_cast<int>(rays.size()), 1);
1748
1749	const VkBufferCreateInfo			resultBufferCreateInfo			= makeBufferCreateInfo(rays.size() * sizeof(tcu::Vec4), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1750	de::MovePtr<BufferWithMemory>		resultBuffer					= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vk, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
1751	const VkDescriptorBufferInfo		resultDescriptorInfo			= makeDescriptorBufferInfo(resultBuffer->get(), 0, VK_WHOLE_SIZE);
1752
1753	const VkBufferCreateInfo			rayBufferCreateInfo				= makeBufferCreateInfo(rays.size() * sizeof(Ray), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1754	de::MovePtr<BufferWithMemory>		rayBuffer						= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vk, device, allocator, rayBufferCreateInfo, MemoryRequirement::HostVisible));
1755	const VkDescriptorBufferInfo		rayDescriptorInfo				= makeDescriptorBufferInfo(rayBuffer->get(), 0, VK_WHOLE_SIZE);
1756	memcpy(rayBuffer->getAllocation().getHostPtr(), &rays[0], rays.size() * sizeof(Ray));
1757	flushMappedMemoryRange(vk, device, rayBuffer->getAllocation().getMemory(), rayBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
1758
1759	const Move<VkCommandPool>			cmdPool							= createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
1760	const Move<VkCommandBuffer>			cmdBuffer						= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1761
1762	beginCommandBuffer(vk, *cmdBuffer);
1763
1764	de::SharedPtr<BottomLevelAccelerationStructure>	blas0 = de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1765	blas0->setGeometryCount(2);
1766	blas0->addGeometry(blas0VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1767	blas0->addGeometry(blas0VertsNoOpaque, true, 0U);
1768	blas0->createAndBuild(vk, device, *cmdBuffer, allocator);
1769
1770	de::SharedPtr<BottomLevelAccelerationStructure>	blas1 = de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1771	blas1->setGeometryCount(2);
1772	blas1->addGeometry(blas1VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1773	blas1->addGeometry(blas1VertsNoOpaque, true, 0U);
1774	blas1->createAndBuild(vk, device, *cmdBuffer, allocator);
1775
1776	de::SharedPtr<BottomLevelAccelerationStructure>	blas2 = de::SharedPtr<BottomLevelAccelerationStructure>(makeBottomLevelAccelerationStructure().release());
1777	blas2->setGeometryCount(2);
1778	blas2->addGeometry(blas2VertsOpaque, true, VK_GEOMETRY_OPAQUE_BIT_KHR);
1779	blas2->addGeometry(blas2VertsNoOpaque, true, 0U);
1780	blas2->createAndBuild(vk, device, *cmdBuffer, allocator);
1781
1782	de::MovePtr<TopLevelAccelerationStructure>	tlas	= makeTopLevelAccelerationStructure();
1783	tlas->setInstanceCount(3);
1784	tlas->addInstance(blas0);
1785	tlas->addInstance(blas1);
1786	tlas->addInstance(blas2);
1787	tlas->createAndBuild(vk, device, *cmdBuffer, allocator);
1788
1789	VkWriteDescriptorSetAccelerationStructureKHR	accelerationStructureWriteDescriptorSet	=
1790	{
1791		VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,	//  VkStructureType						sType;
1792		DE_NULL,															//  const void*							pNext;
1793		1u,																	//  deUint32							accelerationStructureCount;
1794		tlas->getPtr(),														//  const VkAccelerationStructureKHR*	pAccelerationStructures;
1795	};
1796
1797	DescriptorSetUpdateBuilder()
1798		.writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultResultBinding), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &resultDescriptorInfo)
1799		.writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultSceneBinding), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
1800		.writeSingle(*descriptorSet[0], DescriptorSetUpdateBuilder::Location::binding(DefaultRaysBinding), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &rayDescriptorInfo)
1801		.update(vk, device);
1802
1803	vk.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
1804	vk.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet[0].get(), 0, DE_NULL);
1805
1806	cmdTraceRays(vk,
1807		*cmdBuffer,
1808		&raygenShaderBindingTableRegion,
1809		&missShaderBindingTableRegion,
1810		&hitShaderBindingTableRegion,
1811		&callableShaderBindingTableRegion,
1812		static_cast<deUint32>(rays.size()), 1, 1);
1813
1814	endCommandBuffer(vk, *cmdBuffer);
1815
1816	submitCommandsAndWait(vk, device, m_context.getUniversalQueue(), *cmdBuffer);
1817
1818	invalidateMappedMemoryRange(vk, device, resultBuffer->getAllocation().getMemory(), resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
1819
1820	//                 {I}
1821	// (-2,1) (-1,1)  (0,1)  (1,1)  (2,1)
1822	//    X------X------X------X------X
1823	//    |\     |\     |\     |\     |
1824	//    | \ {B}| \ {D}| \ {F}| \ {H}|
1825	// {K}|  \   |  \   |  \   |  \   |{L}
1826	//    |   \  |   \  |   \  |   \  |
1827	//    |{A} \ |{C} \ |{E} \ |{G} \ |
1828	//    |     \|     \|     \|     \|
1829	//    X------X------X------X------X
1830	// (-2,-1)(-1,-1) (0,-1) (1,-1) (2,-1)
1831	//                 {J}
1832	// A, B, E, and F are opaque
1833	// A and C are forced opaque
1834	// E and G are forced non-opaque
1835
1836	std::vector<bool> hits = { true, true, true, true, true, true, true, true, false, false, false, false };
1837	std::vector<bool> opaques = { true, true, true, false, false, true, false, false, true, true, true, true };
1838
1839
1840	union
1841	{
1842		bool     mismatch[32];
1843		deUint32 mismatchAll;
1844	};
1845	mismatchAll = 0;
1846
1847	tcu::Vec4* resultData = (tcu::Vec4*)resultBuffer->getAllocation().getHostPtr();
1848
1849	for (int index = 0; index < resultImage.getWidth(); ++index)
1850	{
1851		if (verifyResultData(&resultData[index], index, hits[index], params))
1852		{
1853			resultImage.setPixel(index, 0, tcu::RGBA(255, 0, 0, 255));
1854		}
1855		else
1856		{
1857			mismatch[index] = true;
1858			resultImage.setPixel(index, 0, tcu::RGBA(0, 0, 0, 255));
1859		}
1860	}
1861
1862	// Write Image
1863	m_context.getTestContext().getLog() << tcu::TestLog::ImageSet("Result of rendering", "Result of rendering")
1864	<< tcu::TestLog::Image("Result", "Result", resultImage)
1865	<< tcu::TestLog::EndImageSet;
1866
1867	if (mismatchAll != 0)
1868		TCU_FAIL("Result data did not match expected output");
1869
1870	return tcu::TestStatus::pass("pass");
1871}
1872
1873}	// anonymous
1874
1875tcu::TestCaseGroup*	createCallableShadersTests (tcu::TestContext& testCtx)
1876{
1877	// Tests veryfying callable shaders
1878	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "callable_shader"));
1879
1880	struct CallableShaderTestTypeData
1881	{
1882		CallableShaderTestType					shaderTestType;
1883		const char*								name;
1884	} callableShaderTestTypes[] =
1885	{
1886		{ CSTT_RGEN_CALL,		"rgen_call"			},
1887		{ CSTT_RGEN_CALL_CALL,	"rgen_call_call"	},
1888		{ CSTT_HIT_CALL,		"hit_call"			},
1889		{ CSTT_RGEN_MULTICALL,	"rgen_multicall"	},
1890	};
1891
1892	for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(callableShaderTestTypes); ++shaderTestNdx)
1893	{
1894		TestParams testParams
1895		{
1896			TEST_WIDTH,
1897			TEST_HEIGHT,
1898			callableShaderTestTypes[shaderTestNdx].shaderTestType,
1899			de::SharedPtr<TestConfiguration>(new SingleSquareConfiguration()),
1900			glu::SHADERTYPE_LAST,
1901			false
1902		};
1903		group->addChild(new CallableShaderTestCase(group->getTestContext(), callableShaderTestTypes[shaderTestNdx].name, testParams));
1904	}
1905
1906    bool                multipleInvocations[]     = { false               , true };
1907    std::string         multipleInvocationsText[] = { "_single_invocation", "_multiple_invocations" };
1908    // Callable shaders cannot be called from any-hit shader per GLSL_NV_ray_tracing spec. Assuming same will hold for KHR version.
1909    glu::ShaderType     invokingShaders[]         = { glu::SHADERTYPE_RAYGEN, glu::SHADERTYPE_CALLABLE, glu::SHADERTYPE_CLOSEST_HIT, glu::SHADERTYPE_MISS };
1910    std::string         invokingShadersText[]     = { "_invoked_via_raygen" , "_invoked_via_callable" , "_invoked_via_closest_hit" , "_invoked_via_miss" };
1911
1912    for (int j = 0; j < 4; ++j)
1913    {
1914		TestParams params;
1915
1916        std::string name("callable_shader");
1917
1918        params.invokingShader = invokingShaders[j];
1919        name.append(invokingShadersText[j]);
1920
1921        for (int k = 0; k < 2; ++k)
1922        {
1923			std::string nameFull(name);
1924
1925            params.multipleInvocations = multipleInvocations[k];
1926			nameFull.append(multipleInvocationsText[k]);
1927
1928			group->addChild(new InvokeCallableShaderTestCase(group->getTestContext(), nameFull.c_str(), params));
1929        }
1930    }
1931
1932	return group.release();
1933}
1934
1935}	// RayTracing
1936
1937}	// vkt
1938