From 2571778a4e5e9ba6ec3c9f64ee016c4259e03e84 Mon Sep 17 00:00:00 2001 From: asuessenbach Date: Wed, 12 Aug 2020 11:38:42 +0200 Subject: [PATCH] [Samples][Ray Tracing] Correct offsets/sizes in the shader binding table. --- samples/RayTracing/RayTracing.cpp | 53 ++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/samples/RayTracing/RayTracing.cpp b/samples/RayTracing/RayTracing.cpp index 1df8530..3e29c71 100644 --- a/samples/RayTracing/RayTracing.cpp +++ b/samples/RayTracing/RayTracing.cpp @@ -640,6 +640,11 @@ glm::vec3 randomVec3( float minValue, float maxValue ) randomDistribution( randomGenerator ) ); } +size_t roundUp( size_t value, size_t alignment ) +{ + return ( ( value + alignment - 1 ) / alignment ) * alignment; +} + int main( int /*argc*/, char ** /*argv*/ ) { // number of cubes in x-, y-, and z-direction @@ -1097,7 +1102,7 @@ int main( int /*argc*/, char ** /*argv*/ ) uint32_t maxRecursionDepth = 2; vk::RayTracingPipelineCreateInfoNV rayTracingPipelineCreateInfo( {}, shaderStages, shaderGroups, maxRecursionDepth, *rayTracingPipelineLayout ); - vk::UniquePipeline rayTracingPipeline; + vk::UniquePipeline rayTracingPipeline; vk::ResultValue rvPipeline = device->createRayTracingPipelineNVUnique( nullptr, rayTracingPipelineCreateInfo ); switch ( rvPipeline.result ) @@ -1109,16 +1114,32 @@ int main( int /*argc*/, char ** /*argv*/ ) default: assert( false ); // should never happen } + vk::StructureChain propertiesChain = + physicalDevice.getProperties2(); + uint32_t shaderGroupBaseAlignment = + propertiesChain.get().shaderGroupBaseAlignment; uint32_t shaderGroupHandleSize = - physicalDevice.getProperties2() - .get() - .shaderGroupHandleSize; - assert( !( shaderGroupHandleSize % 16 ) ); - uint32_t shaderBindingTableSize = 5 * shaderGroupHandleSize; // 1x raygen, 2x miss, 2x hitGroup + propertiesChain.get().shaderGroupHandleSize; - // with 5 shaders, we need a buffer to hold 5 shaderGroupHandles + vk::DeviceSize raygenShaderBindingOffset = 0; // starting with raygen + uint32_t raygenShaderTableSize = shaderGroupHandleSize; // one raygen shader + vk::DeviceSize missShaderBindingOffset = + raygenShaderBindingOffset + roundUp( raygenShaderTableSize, shaderGroupBaseAlignment ); + vk::DeviceSize missShaderBindingStride = shaderGroupHandleSize; + uint32_t missShaderTableSize = vk::su::checked_cast( 2 * missShaderBindingStride ); // two raygen shaders + vk::DeviceSize hitShaderBindingOffset = + missShaderBindingOffset + roundUp( missShaderTableSize, shaderGroupBaseAlignment ); + vk::DeviceSize hitShaderBindingStride = shaderGroupHandleSize; + uint32_t hitShaderTableSize = vk::su::checked_cast( 2 * hitShaderBindingStride ); // two hit shaders + + vk::DeviceSize shaderBindingTableSize = hitShaderBindingOffset + hitShaderTableSize; std::vector shaderHandleStorage( shaderBindingTableSize ); - device->getRayTracingShaderGroupHandlesNV( *rayTracingPipeline, 0, 5, shaderHandleStorage ); + device->getRayTracingShaderGroupHandlesNV( + *rayTracingPipeline, 0, 1, { raygenShaderTableSize, &shaderHandleStorage[raygenShaderBindingOffset] } ); + device->getRayTracingShaderGroupHandlesNV( + *rayTracingPipeline, 1, 2, { missShaderTableSize, &shaderHandleStorage[missShaderBindingOffset] } ); + device->getRayTracingShaderGroupHandlesNV( + *rayTracingPipeline, 3, 2, { hitShaderTableSize, &shaderHandleStorage[hitShaderBindingOffset] } ); vk::su::BufferData shaderBindingTableBufferData( physicalDevice, device, @@ -1250,20 +1271,14 @@ int main( int /*argc*/, char ** /*argv*/ ) *rayTracingDescriptorSets[backBufferIndex], nullptr ); - VkDeviceSize rayGenOffset = 0; // starting with raygen - VkDeviceSize missOffset = shaderGroupHandleSize; // after raygen - VkDeviceSize missStride = shaderGroupHandleSize; - VkDeviceSize hitGroupOffset = shaderGroupHandleSize + 2 * shaderGroupHandleSize; // after 1x raygen and 2x miss - VkDeviceSize hitGroupStride = shaderGroupHandleSize; - commandBuffer->traceRaysNV( *shaderBindingTableBufferData.buffer, - rayGenOffset, + raygenShaderBindingOffset, *shaderBindingTableBufferData.buffer, - missOffset, - missStride, + missShaderBindingOffset, + missShaderBindingStride, *shaderBindingTableBufferData.buffer, - hitGroupOffset, - hitGroupStride, + hitShaderBindingOffset, + hitShaderBindingStride, nullptr, 0, 0,