1/*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2022 The Khronos Group Inc.
6 * Copyright (c) 2022 NVIDIA Corporation.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 *      http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Ray Query Position Fetch Tests
23 *//*--------------------------------------------------------------------*/
24
25#include "vktRayQueryPositionFetchTests.hpp"
26#include "vktTestCase.hpp"
27
28#include "vkRayTracingUtil.hpp"
29#include "vkObjUtil.hpp"
30#include "vkCmdUtil.hpp"
31#include "vkBufferWithMemory.hpp"
32#include "vkBuilderUtil.hpp"
33#include "vkTypeUtil.hpp"
34#include "vkBarrierUtil.hpp"
35#include "vktTestGroupUtil.hpp"
36
37#include "deUniquePtr.hpp"
38#include "deRandom.hpp"
39
40#include "tcuVectorUtil.hpp"
41
42#include <sstream>
43#include <vector>
44#include <iostream>
45
46namespace vkt
47{
48namespace RayQuery
49{
50
51namespace
52{
53
54using namespace vk;
55
// Kind of pipeline that hosts the shader issuing the ray query.
enum ShaderSourcePipeline
{
	SSP_GRAPHICS_PIPELINE		= 0,	// Vertex shader in a graphics pipeline.
	SSP_COMPUTE_PIPELINE		= 1,	// Compute shader in a compute pipeline.
	SSP_RAY_TRACING_PIPELINE	= 2		// Ray generation shader in a ray tracing pipeline.
};
62
// Shader stage in which the ray query and position fetch are executed.
enum ShaderSourceType
{
	SST_VERTEX_SHADER			= 0,
	SST_COMPUTE_SHADER			= 1,
	SST_RAY_GENERATION_SHADER	= 2,
};
69
// Optional test behaviours, combined into TestParams::testFlagMask.
// TEST_FLAG_BIT_LAST is one past the highest flag bit and bounds the mask enumeration.
enum TestFlagBits
{
	TEST_FLAG_BIT_INSTANCE_TRANSFORM				= 0x1u,	// Use a non-identity TLAS instance transform.
	TEST_FLAG_BIT_LAST								= 0x2u,
};
75
// Test-name fragment for each TestFlagBits bit, indexed by bit position
// (entry N describes the flag with value 1u << N). Read-only lookup table.
const std::vector<std::string> testFlagBitNames =
{
	"instance_transform",
};
80
81struct TestParams
82{
83	ShaderSourceType		shaderSourceType;
84	ShaderSourcePipeline	shaderSourcePipeline;
85	vk::VkAccelerationStructureBuildTypeKHR	buildType;		// are we making AS on CPU or GPU
86	VkFormat								vertexFormat;
87	deUint32								testFlagMask;
88};
89
90static constexpr deUint32 kNumThreadsAtOnce = 1024;
91
92
93class PositionFetchCase : public TestCase
94{
95public:
96							PositionFetchCase		(tcu::TestContext& testCtx, const std::string& name, const TestParams& params);
97	virtual					~PositionFetchCase	(void) {}
98
99	virtual void			checkSupport				(Context& context) const;
100	virtual void			initPrograms				(vk::SourceCollections& programCollection) const;
101	virtual TestInstance*	createInstance				(Context& context) const;
102
103protected:
104	TestParams				m_params;
105};
106
107class PositionFetchInstance : public TestInstance
108{
109public:
110								PositionFetchInstance		(Context& context, const TestParams& params);
111	virtual						~PositionFetchInstance	(void) {}
112
113	virtual tcu::TestStatus		iterate							(void);
114
115protected:
116	TestParams					m_params;
117};
118
119PositionFetchCase::PositionFetchCase (tcu::TestContext& testCtx, const std::string& name, const TestParams& params)
120	: TestCase	(testCtx, name)
121	, m_params	(params)
122{}
123
124void PositionFetchCase::checkSupport (Context& context) const
125{
126	context.requireDeviceFunctionality("VK_KHR_ray_query");
127	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
128	context.requireDeviceFunctionality("VK_KHR_ray_tracing_position_fetch");
129
130	const VkPhysicalDeviceRayQueryFeaturesKHR& rayQueryFeaturesKHR = context.getRayQueryFeatures();
131	if (rayQueryFeaturesKHR.rayQuery == DE_FALSE)
132		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayQueryFeaturesKHR.rayQuery");
133
134	const VkPhysicalDeviceAccelerationStructureFeaturesKHR& accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
135	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
136		TCU_THROW(TestError, "VK_KHR_ray_query requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
137
138	if (m_params.buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR && accelerationStructureFeaturesKHR.accelerationStructureHostCommands == DE_FALSE)
139		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructureHostCommands");
140
141	const VkPhysicalDeviceRayTracingPositionFetchFeaturesKHR& rayTracingPositionFetchFeaturesKHR = context.getRayTracingPositionFetchFeatures();
142	if (rayTracingPositionFetchFeaturesKHR.rayTracingPositionFetch == DE_FALSE)
143		TCU_THROW(NotSupportedError, "Requires VkPhysicalDevicePositionFetchFeaturesKHR.rayTracingPositionFetch");
144
145	// Check supported vertex format.
146	checkAccelerationStructureVertexBufferFormat(context.getInstanceInterface(), context.getPhysicalDevice(), m_params.vertexFormat);
147
148	if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
149	{
150		context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
151
152		const VkPhysicalDeviceRayTracingPipelineFeaturesKHR& rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
153
154		if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE)
155			TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
156	}
157
158	switch (m_params.shaderSourceType)
159	{
160	case SST_VERTEX_SHADER:
161		context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
162		break;
163	default:
164		break;
165	}
166}
167
168void PositionFetchCase::initPrograms (vk::SourceCollections& programCollection) const
169{
170	const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
171
172	deUint32 numRays = 1; // XXX
173
174	std::ostringstream sharedHeader;
175	sharedHeader
176		<< "#version 460 core\n"
177		<< "#extension GL_EXT_ray_query : require\n"
178		<< "#extension GL_EXT_ray_tracing_position_fetch : require\n"
179		<< "\n"
180		<< "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
181		<< "layout(set=0, binding=1, std430) buffer RayOrigins {\n"
182		<< "  vec4 values[" << numRays << "];\n"
183		<< "} origins;\n"
184		<< "layout(set=0, binding=2, std430) buffer OutputPositions {\n"
185		<< "  vec4 values[" << 3*numRays << "];\n"
186		<< "} modes;\n";
187
188	std::ostringstream mainLoop;
189	mainLoop
190		<< "  while (index < " << numRays << ") {\n"
191		//<< "     for (int i=0; i<3; i++) {\n"
192		//<< "       modes.values[3*index.x+i] = vec4(i, 0.0, 5.0, 1.0);\n"
193		//<< "     }\n"
194		<< "    const uint  cullMask  = 0xFF;\n"
195		<< "    const vec3  origin    = origins.values[index].xyz;\n"
196		<< "    const vec3  direction = vec3(0.0, 0.0, -1.0);\n"
197		<< "    const float tMin      = 0.0f;\n"
198		<< "    const float tMax      = 2.0f;\n"
199		<< "    rayQueryEXT rq;\n"
200		<< "    rayQueryInitializeEXT(rq, topLevelAS, gl_RayFlagsNoneEXT, cullMask, origin, tMin, direction, tMax);\n"
201		<< "    while (rayQueryProceedEXT(rq)) {\n"
202		<< "      if (rayQueryGetIntersectionTypeEXT(rq, false) == gl_RayQueryCandidateIntersectionTriangleEXT) {\n"
203		<< "        vec3 outputVal[3];\n"
204		<< "        rayQueryGetIntersectionTriangleVertexPositionsEXT(rq, false, outputVal);\n"
205		<< "        for (int i=0; i<3; i++) {\n"
206		<< "           modes.values[3*index.x+i] = vec4(outputVal[i], 0);\n"
207//		<< "           modes.values[3*index.x+i] = vec4(1.0, 1.0, 1.0, 0);\n"
208		<< "        }\n"
209		<< "      }\n"
210		<< "    }\n"
211		<< "    index += " << kNumThreadsAtOnce << ";\n"
212		<< "  }\n";
213
214	if (m_params.shaderSourceType == SST_VERTEX_SHADER) {
215		std::ostringstream vert;
216		vert
217			<< sharedHeader.str()
218			<< "void main()\n"
219			<< "{\n"
220			<< "  uint index             = gl_VertexIndex.x;\n"
221			<< mainLoop.str()
222			<< "}\n"
223			;
224
225		programCollection.glslSources.add("vert") << glu::VertexSource(vert.str()) << buildOptions;
226	}
227	else if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
228	{
229		std::ostringstream rgen;
230		rgen
231			<< sharedHeader.str()
232			<< "#extension GL_EXT_ray_tracing : require\n"
233			<< "void main()\n"
234			<< "{\n"
235			<< "  uint index             = gl_LaunchIDEXT.x;\n"
236			<< mainLoop.str()
237			<< "}\n"
238			;
239
240		programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
241	}
242	else
243	{
244		DE_ASSERT(m_params.shaderSourceType == SST_COMPUTE_SHADER);
245		std::ostringstream comp;
246		comp
247			<< sharedHeader.str()
248			<< "layout(local_size_x=1024, local_size_y=1, local_size_z=1) in;\n"
249			<< "\n"
250			<< "void main()\n"
251			<< "{\n"
252			<< "  uint index             = gl_LocalInvocationID.x;\n"
253			<< mainLoop.str()
254			<< "}\n"
255			;
256
257		programCollection.glslSources.add("comp") << glu::ComputeSource(updateRayTracingGLSL(comp.str())) << buildOptions;
258	}
259}
260
261TestInstance* PositionFetchCase::createInstance (Context& context) const
262{
263	return new PositionFetchInstance(context, m_params);
264}
265
266PositionFetchInstance::PositionFetchInstance (Context& context, const TestParams& params)
267	: TestInstance	(context)
268	, m_params		(params)
269{}
270
271static Move<VkRenderPass> makeEmptyRenderPass(const DeviceInterface& vk,
272	const VkDevice				device)
273{
274	std::vector<VkSubpassDescription>	subpassDescriptions;
275
276	const VkSubpassDescription	description =
277	{
278		(VkSubpassDescriptionFlags)0,		//  VkSubpassDescriptionFlags		flags;
279		VK_PIPELINE_BIND_POINT_GRAPHICS,	//  VkPipelineBindPoint				pipelineBindPoint;
280		0u,									//  deUint32						inputAttachmentCount;
281		DE_NULL,							//  const VkAttachmentReference*	pInputAttachments;
282		0u,									//  deUint32						colorAttachmentCount;
283		DE_NULL,							//  const VkAttachmentReference*	pColorAttachments;
284		DE_NULL,							//  const VkAttachmentReference*	pResolveAttachments;
285		DE_NULL,							//  const VkAttachmentReference*	pDepthStencilAttachment;
286		0,									//  deUint32						preserveAttachmentCount;
287		DE_NULL								//  const deUint32*					pPreserveAttachments;
288	};
289	subpassDescriptions.push_back(description);
290
291	const VkRenderPassCreateInfo renderPassInfo =
292	{
293		VK_STRUCTURE_TYPE_RENDER_PASS_CREATE_INFO,							//  VkStructureType					sType;
294		DE_NULL,															//  const void*						pNext;
295		static_cast<VkRenderPassCreateFlags>(0u),							//  VkRenderPassCreateFlags			flags;
296		0u,																	//  deUint32						attachmentCount;
297		DE_NULL,															//  const VkAttachmentDescription*	pAttachments;
298		static_cast<deUint32>(subpassDescriptions.size()),					//  deUint32						subpassCount;
299		&subpassDescriptions[0],											//  const VkSubpassDescription*		pSubpasses;
300		0u,																	//  deUint32						dependencyCount;
301		DE_NULL																//  const VkSubpassDependency*		pDependencies;
302	};
303
304	return createRenderPass(vk, device, &renderPassInfo);
305}
306
307static Move<VkFramebuffer> makeFramebuffer (const DeviceInterface& vk, const VkDevice device, VkRenderPass renderPass, uint32_t width, uint32_t height)
308{
309	const vk::VkFramebufferCreateInfo	framebufferParams =
310	{
311		vk::VK_STRUCTURE_TYPE_FRAMEBUFFER_CREATE_INFO,					// sType
312		DE_NULL,														// pNext
313		(vk::VkFramebufferCreateFlags)0,
314		renderPass,														// renderPass
315		0u,																// attachmentCount
316		DE_NULL,														// pAttachments
317		width,															// width
318		height,															// height
319		1u,																// layers
320	};
321
322	return createFramebuffer(vk, device, &framebufferParams);
323}
324
325Move<VkPipeline> makeGraphicsPipeline(const DeviceInterface& vk,
326	const VkDevice				device,
327	const VkPipelineLayout		pipelineLayout,
328	const VkRenderPass			renderPass,
329	const VkShaderModule		vertexModule,
330	const deUint32				subpass)
331{
332	VkExtent2D												renderSize { 256, 256 };
333	VkViewport												viewport = makeViewport(renderSize);
334	VkRect2D												scissor = makeRect2D(renderSize);
335
336	const VkPipelineViewportStateCreateInfo					viewportStateCreateInfo =
337	{
338		VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO,		// VkStructureType                             sType
339		DE_NULL,													// const void*                                 pNext
340		(VkPipelineViewportStateCreateFlags)0,						// VkPipelineViewportStateCreateFlags          flags
341		1u,															// deUint32                                    viewportCount
342		&viewport,													// const VkViewport*                           pViewports
343		1u,															// deUint32                                    scissorCount
344		&scissor													// const VkRect2D*                             pScissors
345	};
346
347	const VkPipelineInputAssemblyStateCreateInfo			inputAssemblyStateCreateInfo =
348	{
349		VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO,	// VkStructureType                            sType
350		DE_NULL,														// const void*                                pNext
351		0u,																// VkPipelineInputAssemblyStateCreateFlags    flags
352		VK_PRIMITIVE_TOPOLOGY_POINT_LIST,								// VkPrimitiveTopology                        topology
353		VK_FALSE														// VkBool32                                   primitiveRestartEnable
354	};
355
356	const VkPipelineVertexInputStateCreateInfo				vertexInputStateCreateInfo =
357	{
358		VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,									//  VkStructureType									sType
359		DE_NULL,																					//  const void*										pNext
360		(VkPipelineVertexInputStateCreateFlags)0,													//  VkPipelineVertexInputStateCreateFlags			flags
361		0u,																							//  deUint32										vertexBindingDescriptionCount
362		DE_NULL,																					//  const VkVertexInputBindingDescription*			pVertexBindingDescriptions
363		0u,																							//  deUint32										vertexAttributeDescriptionCount
364		DE_NULL,																					//  const VkVertexInputAttributeDescription*		pVertexAttributeDescriptions
365	};
366
367	const VkPipelineRasterizationStateCreateInfo			rasterizationStateCreateInfo =
368	{
369		VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO,	//  VkStructureType							sType
370		DE_NULL,													//  const void*								pNext
371		0u,															//  VkPipelineRasterizationStateCreateFlags	flags
372		VK_FALSE,													//  VkBool32								depthClampEnable
373		VK_TRUE,													//  VkBool32								rasterizerDiscardEnable
374		VK_POLYGON_MODE_FILL,										//  VkPolygonMode							polygonMode
375		VK_CULL_MODE_NONE,											//  VkCullModeFlags							cullMode
376		VK_FRONT_FACE_COUNTER_CLOCKWISE,							//  VkFrontFace								frontFace
377		VK_FALSE,													//  VkBool32								depthBiasEnable
378		0.0f,														//  float									depthBiasConstantFactor
379		0.0f,														//  float									depthBiasClamp
380		0.0f,														//  float									depthBiasSlopeFactor
381		1.0f														//  float									lineWidth
382	};
383
384	return makeGraphicsPipeline(vk,									// const DeviceInterface&							vk
385		device,								// const VkDevice									device
386		pipelineLayout,						// const VkPipelineLayout							pipelineLayout
387		vertexModule,						// const VkShaderModule								vertexShaderModule
388		DE_NULL,							// const VkShaderModule								tessellationControlModule
389		DE_NULL,							// const VkShaderModule								tessellationEvalModule
390		DE_NULL,							// const VkShaderModule								geometryShaderModule
391		DE_NULL,							// const VkShaderModule								fragmentShaderModule
392		renderPass,							// const VkRenderPass								renderPass
393		subpass,							// const deUint32									subpass
394		&vertexInputStateCreateInfo,		// const VkPipelineVertexInputStateCreateInfo*		vertexInputStateCreateInfo
395		&inputAssemblyStateCreateInfo,		// const VkPipelineInputAssemblyStateCreateInfo*	inputAssemblyStateCreateInfo
396		DE_NULL,							// const VkPipelineTessellationStateCreateInfo*		tessStateCreateInfo
397		&viewportStateCreateInfo,			// const VkPipelineViewportStateCreateInfo*			viewportStateCreateInfo
398		&rasterizationStateCreateInfo);	// const VkPipelineRasterizationStateCreateInfo*	rasterizationStateCreateInfo
399}
400
401tcu::TestStatus PositionFetchInstance::iterate (void)
402{
403	const auto&	vkd		= m_context.getDeviceInterface();
404	const auto	device	= m_context.getDevice();
405	auto&		alloc	= m_context.getDefaultAllocator();
406	const auto	qIndex	= m_context.getUniversalQueueFamilyIndex();
407	const auto	queue	= m_context.getUniversalQueue();
408
409	// Command pool and buffer.
410	const auto cmdPool		= makeCommandPool(vkd, device, qIndex);
411	const auto cmdBufferPtr	= allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
412	const auto cmdBuffer	= cmdBufferPtr.get();
413
414	beginCommandBuffer(vkd, cmdBuffer);
415
416	// Build acceleration structures.
417	auto topLevelAS		= makeTopLevelAccelerationStructure();
418	auto bottomLevelAS	= makeBottomLevelAccelerationStructure();
419
420	const std::vector<tcu::Vec3> triangle =
421	{
422		tcu::Vec3(0.0f, 0.0f, 0.0f),
423		tcu::Vec3(1.0f, 0.0f, 0.0f),
424		tcu::Vec3(0.0f, 1.0f, 0.0f),
425	};
426
427	const VkTransformMatrixKHR notQuiteIdentityMatrix3x4 = { { { 0.98f, 0.0f, 0.0f, 0.0f }, { 0.0f, 0.97f, 0.0f, 0.0f }, { 0.0f, 0.0f, 0.99f, 0.0f } } };
428
429	de::SharedPtr<RaytracedGeometryBase> geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, m_params.vertexFormat, VK_INDEX_TYPE_NONE_KHR);
430
431	for (auto & v : triangle) {
432		geometry->addVertex(v);
433	}
434
435	bottomLevelAS->addGeometry(geometry);
436	bottomLevelAS->setBuildFlags(VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_DATA_ACCESS_KHR);
437	bottomLevelAS->setBuildType(m_params.buildType);
438	bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
439	de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr (bottomLevelAS.release());
440
441	topLevelAS->setInstanceCount(1);
442	topLevelAS->setBuildType(m_params.buildType);
443	topLevelAS->addInstance(blasSharedPtr, (m_params.testFlagMask & TEST_FLAG_BIT_INSTANCE_TRANSFORM) ? notQuiteIdentityMatrix3x4 : identityMatrix3x4);
444	topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
445
446	// One ray for this test
447	// XXX Should it be multiple triangles and one ray per triangle for more coverage?
448	// XXX If it's really one ray, the origin buffer is complete overkill
449	deUint32 numRays = 1; // XXX
450
451	// SSBO buffer for origins.
452	const auto originsBufferSize		= static_cast<VkDeviceSize>(sizeof(tcu::Vec4) * numRays);
453	const auto originsBufferInfo		= makeBufferCreateInfo(originsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
454	BufferWithMemory originsBuffer	(vkd, device, alloc, originsBufferInfo, MemoryRequirement::HostVisible);
455	auto& originsBufferAlloc			= originsBuffer.getAllocation();
456	void* originsBufferData				= originsBufferAlloc.getHostPtr();
457
458	std::vector<tcu::Vec4> origins;
459	std::vector<tcu::Vec3> expectedOutputPositions;
460	origins.reserve(numRays);
461	expectedOutputPositions.reserve(3*numRays);
462
463	// Fill in vector of expected outputs
464	for (deUint32 index = 0; index < numRays; index++) {
465		for (deUint32 vert = 0; vert < 3; vert++) {
466			tcu::Vec3 pos = triangle[vert];
467
468			expectedOutputPositions.push_back(pos);
469		}
470	}
471
472	// XXX Arbitrary location and see above
473	for (deUint32 index = 0; index < numRays; index++) {
474		origins.push_back(tcu::Vec4(0.25, 0.25, 1.0, 0.0));
475	}
476
477	const auto				originsBufferSizeSz = static_cast<size_t>(originsBufferSize);
478	deMemcpy(originsBufferData, origins.data(), originsBufferSizeSz);
479	flushAlloc(vkd, device, originsBufferAlloc);
480
481	// Storage buffer for output modes
482	const auto outputPositionsBufferSize = static_cast<VkDeviceSize>(3 * 4 * sizeof(float) * numRays);
483	const auto outputPositionsBufferInfo = makeBufferCreateInfo(outputPositionsBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
484	BufferWithMemory outputPositionsBuffer(vkd, device, alloc, outputPositionsBufferInfo, MemoryRequirement::HostVisible);
485	auto& outputPositionsBufferAlloc = outputPositionsBuffer.getAllocation();
486	void* outputPositionsBufferData = outputPositionsBufferAlloc.getHostPtr();
487	deMemset(outputPositionsBufferData, 0xFF, static_cast<size_t>(outputPositionsBufferSize));
488	flushAlloc(vkd, device, outputPositionsBufferAlloc);
489
490	// Descriptor set layout.
491	DescriptorSetLayoutBuilder dsLayoutBuilder;
492	dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, VK_SHADER_STAGE_ALL);
493	dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_SHADER_STAGE_ALL);
494	dsLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_SHADER_STAGE_ALL);
495	const auto setLayout = dsLayoutBuilder.build(vkd, device);
496
497	// Pipeline layout.
498	const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
499
500	// Descriptor pool and set.
501	DescriptorPoolBuilder poolBuilder;
502	poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
503	poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
504	poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
505	const auto descriptorPool	= poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
506	const auto descriptorSet	= makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
507
508	// Update descriptor set.
509	{
510		const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo =
511		{
512			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
513			nullptr,
514			1u,
515			topLevelAS.get()->getPtr(),
516		};
517		const auto inStorageBufferInfo = makeDescriptorBufferInfo(originsBuffer.get(), 0ull, VK_WHOLE_SIZE);
518		const auto storageBufferInfo = makeDescriptorBufferInfo(outputPositionsBuffer.get(), 0ull, VK_WHOLE_SIZE);
519
520		DescriptorSetUpdateBuilder updateBuilder;
521		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
522		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inStorageBufferInfo);
523		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferInfo);
524		updateBuilder.update(vkd, device);
525	}
526
527	Move<VkPipeline>				pipeline;
528	de::MovePtr<BufferWithMemory>	raygenSBT;
529	Move<VkRenderPass>				renderPass;
530	Move<VkFramebuffer>				framebuffer;
531
532	if (m_params.shaderSourceType == SST_VERTEX_SHADER)
533	{
534		auto vertexModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("vert"), 0);
535
536		const uint32_t width = 32u;
537		const uint32_t height = 32u;
538		renderPass = makeEmptyRenderPass(vkd, device);
539		framebuffer = makeFramebuffer(vkd, device, *renderPass, width, height);
540		pipeline = makeGraphicsPipeline(vkd, device, *pipelineLayout, *renderPass, *vertexModule, 0);
541
542		const VkRenderPassBeginInfo			renderPassBeginInfo =
543		{
544			VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO,							// VkStructureType								sType;
545			DE_NULL,															// const void*									pNext;
546			*renderPass,														// VkRenderPass									renderPass;
547			*framebuffer,														// VkFramebuffer								framebuffer;
548			makeRect2D(width, height),											// VkRect2D										renderArea;
549			0u,																	// uint32_t										clearValueCount;
550			DE_NULL																// const VkClearValue*							pClearValues;
551		};
552
553		vkd.cmdBeginRenderPass(cmdBuffer, &renderPassBeginInfo, VK_SUBPASS_CONTENTS_INLINE);
554		vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline.get());
555		vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
556		vkd.cmdDraw(cmdBuffer, kNumThreadsAtOnce, 1, 0, 0);
557		vkd.cmdEndRenderPass(cmdBuffer);
558	}
559	else if (m_params.shaderSourceType == SST_RAY_GENERATION_SHADER)
560	{
561		const auto& vki = m_context.getInstanceInterface();
562		const auto	physDev = m_context.getPhysicalDevice();
563
564		// Shader module.
565		auto rgenModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0);
566
567		// Get some ray tracing properties.
568		deUint32 shaderGroupHandleSize = 0u;
569		deUint32 shaderGroupBaseAlignment = 1u;
570		{
571			const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
572			shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
573			shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
574		}
575
576		auto raygenSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
577		auto unusedSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
578
579		{
580			const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
581			rayTracingPipeline->setCreateFlags(VK_PIPELINE_CREATE_RAY_TRACING_OPACITY_MICROMAP_BIT_EXT);
582			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, 0);
583
584			pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
585
586			raygenSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
587			raygenSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
588		}
589
590		// Trace rays.
591		vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
592		vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
593		vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &unusedSBTRegion, &unusedSBTRegion, &unusedSBTRegion, kNumThreadsAtOnce, 1u, 1u);
594	}
595	else
596	{
597		DE_ASSERT(m_params.shaderSourceType == SST_COMPUTE_SHADER);
598		// Shader module.
599		const auto compModule = createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0);
600
601		// Pipeline.
602		const VkPipelineShaderStageCreateInfo shaderInfo =
603		{
604			VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,	//	VkStructureType						sType;
605			nullptr,												//	const void*							pNext;
606			0u,														//	VkPipelineShaderStageCreateFlags	flags;
607			VK_SHADER_STAGE_COMPUTE_BIT,							//	VkShaderStageFlagBits				stage;
608			compModule.get(),										//	VkShaderModule						module;
609			"main",													//	const char*							pName;
610			nullptr,												//	const VkSpecializationInfo*			pSpecializationInfo;
611		};
612		const VkComputePipelineCreateInfo pipelineInfo =
613		{
614			VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,	//	VkStructureType					sType;
615			nullptr,										//	const void*						pNext;
616			0u,												//	VkPipelineCreateFlags			flags;
617			shaderInfo,										//	VkPipelineShaderStageCreateInfo	stage;
618			pipelineLayout.get(),							//	VkPipelineLayout				layout;
619			DE_NULL,										//	VkPipeline						basePipelineHandle;
620			0,												//	deInt32							basePipelineIndex;
621		};
622		pipeline = createComputePipeline(vkd, device, DE_NULL, &pipelineInfo);
623
624		// Dispatch work with ray queries.
625		vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
626		vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
627		vkd.cmdDispatch(cmdBuffer, 1u, 1u, 1u);
628	}
629
630	// Barrier for the output buffer.
631	const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
632	vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u, &bufferBarrier, 0u, nullptr, 0u, nullptr);
633
634	endCommandBuffer(vkd, cmdBuffer);
635	submitCommandsAndWait(vkd, device, queue, cmdBuffer);
636
637	// Verify results.
638	std::vector<tcu::Vec4>	outputData(expectedOutputPositions.size());
639	const auto				outputPositionsBufferSizeSz = static_cast<size_t>(outputPositionsBufferSize);
640
641	invalidateAlloc(vkd, device, outputPositionsBufferAlloc);
642	DE_ASSERT(de::dataSize(outputData) == outputPositionsBufferSizeSz);
643	deMemcpy(outputData.data(), outputPositionsBufferData, outputPositionsBufferSizeSz);
644
645	for (size_t i = 0; i < outputData.size(); ++i)
646	{
647		/*const */ auto& outVal = outputData[i]; // Should be const but .xyz() isn't
648		tcu::Vec3 outVec3 = outVal.xyz();
649		const auto& expectedVal = expectedOutputPositions[i];
650		const auto& diff = expectedOutputPositions[i] - outVec3;
651		float len = dot(diff, diff);
652
653		// XXX Find a better epsilon
654		if (!(len < 1e-5))
655		{
656			std::ostringstream msg;
657			msg << "Unexpected value found for element " << i << ": expected " << expectedVal << " and found " << outVal << ";";
658			TCU_FAIL(msg.str());
659		}
660#if 0
661		else
662		{
663			std::ostringstream msg;
664			msg << "Expected value found for element " << i << ": expected " << expectedVal << " and found " << outVal << ";\n";
665			std::cout << msg.str();
666		}
667#endif
668	}
669
670	return tcu::TestStatus::pass("Pass");
671}
672
673} // anonymous
674
675tcu::TestCaseGroup*	createPositionFetchTests (tcu::TestContext& testCtx)
676{
677	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "position_fetch"));
678
679	struct
680	{
681		vk::VkAccelerationStructureBuildTypeKHR				buildType;
682		const char* name;
683	} buildTypes[] =
684	{
685		{ VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR,	"cpu_built"	},
686		{ VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR,	"gpu_built"	},
687	};
688
689
690	const struct
691	{
692		ShaderSourceType						shaderSourceType;
693		ShaderSourcePipeline					shaderSourcePipeline;
694		std::string								name;
695	} shaderSourceTypes[] =
696	{
697		{ SST_VERTEX_SHADER,					SSP_GRAPHICS_PIPELINE,		"vertex_shader"				},
698		{ SST_COMPUTE_SHADER,					SSP_COMPUTE_PIPELINE,		"compute_shader",			},
699		{ SST_RAY_GENERATION_SHADER,			SSP_RAY_TRACING_PIPELINE,	"rgen_shader",				},
700	};
701
702	const VkFormat vertexFormats[] =
703	{
704		// Mandatory formats.
705		VK_FORMAT_R32G32_SFLOAT,
706		VK_FORMAT_R32G32B32_SFLOAT,
707		VK_FORMAT_R16G16_SFLOAT,
708		VK_FORMAT_R16G16B16A16_SFLOAT,
709		VK_FORMAT_R16G16_SNORM,
710		VK_FORMAT_R16G16B16A16_SNORM,
711
712		// Additional formats.
713		VK_FORMAT_R8G8_SNORM,
714		VK_FORMAT_R8G8B8_SNORM,
715		VK_FORMAT_R8G8B8A8_SNORM,
716		VK_FORMAT_R16G16B16_SNORM,
717		VK_FORMAT_R16G16B16_SFLOAT,
718		VK_FORMAT_R32G32B32A32_SFLOAT,
719		VK_FORMAT_R64G64_SFLOAT,
720		VK_FORMAT_R64G64B64_SFLOAT,
721		VK_FORMAT_R64G64B64A64_SFLOAT,
722	};
723
724	for (size_t shaderSourceNdx = 0; shaderSourceNdx < DE_LENGTH_OF_ARRAY(shaderSourceTypes); ++shaderSourceNdx)
725	{
726		de::MovePtr<tcu::TestCaseGroup> sourceTypeGroup(new tcu::TestCaseGroup(group->getTestContext(), shaderSourceTypes[shaderSourceNdx].name.c_str()));
727
728		for (size_t buildTypeNdx = 0; buildTypeNdx < DE_LENGTH_OF_ARRAY(buildTypes); ++buildTypeNdx)
729		{
730			de::MovePtr<tcu::TestCaseGroup> buildGroup(new tcu::TestCaseGroup(group->getTestContext(), buildTypes[buildTypeNdx].name));
731
732			for (size_t vertexFormatNdx = 0; vertexFormatNdx < DE_LENGTH_OF_ARRAY(vertexFormats); ++vertexFormatNdx)
733			{
734				const auto format = vertexFormats[vertexFormatNdx];
735				const auto formatName = getFormatSimpleName(format);
736
737				de::MovePtr<tcu::TestCaseGroup> vertexFormatGroup(new tcu::TestCaseGroup(group->getTestContext(), formatName.c_str()));
738
739				for (deUint32 testFlagMask = 0; testFlagMask < TEST_FLAG_BIT_LAST; testFlagMask++)
740				{
741					std::string maskName = "";
742
743					for (deUint32 bit = 0; bit < testFlagBitNames.size(); bit++)
744					{
745						if (testFlagMask & (1 << bit))
746						{
747							if (maskName != "")
748								maskName += "_";
749							maskName += testFlagBitNames[bit];
750						}
751					}
752					if (maskName == "")
753						maskName = "NoFlags";
754
755					de::MovePtr<tcu::TestCaseGroup> testFlagGroup(new tcu::TestCaseGroup(group->getTestContext(), maskName.c_str()));
756
757					TestParams testParams
758					{
759						shaderSourceTypes[shaderSourceNdx].shaderSourceType,
760						shaderSourceTypes[shaderSourceNdx].shaderSourcePipeline,
761						buildTypes[buildTypeNdx].buildType,
762						format,
763						testFlagMask,
764					};
765
766					vertexFormatGroup->addChild(new PositionFetchCase(testCtx, maskName, testParams));
767				}
768				buildGroup->addChild(vertexFormatGroup.release());
769			}
770			sourceTypeGroup->addChild(buildGroup.release());
771		}
772		group->addChild(sourceTypeGroup.release());
773	}
774
775	return group.release();
776}
777} // RayQuery
778} // vkt
779