From 33f4a2e6c55cf3df7e31ec1b740e1f309073c324 Mon Sep 17 00:00:00 2001 From: Swung0x48 Date: Mon, 21 Apr 2025 10:34:35 +0800 Subject: [PATCH] [Fix] (multidraw): allocate output buffer --- src/main/cpp/gl/multidraw.cpp | 26 +++++++++++++++------ src/main/cpp/shaders/multidraw_compute.comp | 12 +++++----- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/main/cpp/gl/multidraw.cpp b/src/main/cpp/gl/multidraw.cpp index 7f8bed3..bb4c300 100644 --- a/src/main/cpp/gl/multidraw.cpp +++ b/src/main/cpp/gl/multidraw.cpp @@ -263,17 +263,16 @@ layout(std430, binding = 2) readonly buffer Prefix { uint prefixSums[]; }; layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; }; void main() { - uint globalIdx = gl_GlobalInvocationID.x; - if (globalIdx >= prefixSums[prefixSums.length() - 1]) + uint outIdx = gl_GlobalInvocationID.x; + if (outIdx >= prefixSums[prefixSums.length() - 1]) return; - out_indices[globalIdx] = globalIdx; // bisect to find out draw call # int low = 0; int high = draws.length() - 1; while(low < high) { int mid = (low + high + 1) / 2; - if (prefixSums[mid] <= globalIdx) { + if (prefixSums[mid] <= outIdx) { low = mid; } else { high = mid - 1; @@ -282,11 +281,11 @@ void main() { // figure out which index to take DrawCommand cmd = draws[low]; - uint localIdx = globalIdx - prefixSums[low]; + uint localIdx = outIdx - prefixSums[low]; uint srcIndex = cmd.firstIndex + localIdx; // Write out - out_indices[globalIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); + out_indices[outIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); } )"; @@ -382,6 +381,15 @@ GLAPI GLAPIENTRY void mg_glMultiDrawElementsBaseVertex_compute( GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); CHECK_GL_ERROR_NO_INIT + // Allocate output buffer + auto total_indices = g_prefix_sum[primcount - 1]; + GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, g_outputibo); + CHECK_GL_ERROR_NO_INIT + GLES.glBufferData(GL_SHADER_STORAGE_BUFFER, sizeof(GLuint) * total_indices, nullptr, GL_DYNAMIC_DRAW); + CHECK_GL_ERROR_NO_INIT + GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); + CHECK_GL_ERROR_NO_INIT + GLint ibo = 0; GLES.glGetIntegerv(GL_ELEMENT_ARRAY_BUFFER_BINDING, &ibo); CHECK_GL_ERROR_NO_INIT @@ -396,11 +404,15 @@ GLAPI GLAPIENTRY void mg_glMultiDrawElementsBaseVertex_compute( GLES.glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 3, g_outputibo); CHECK_GL_ERROR_NO_INIT + // Save states + GLint prev_program = 0; + GLES.glGetIntegerv(GL_CURRENT_PROGRAM, &prev_program); + CHECK_GL_ERROR_NO_INIT + // Dispatch compute LOG_D("Using compute program = %d", g_compute_program) GLES.glUseProgram(g_compute_program); CHECK_GL_ERROR_NO_INIT - GLuint total_indices = g_prefix_sum[primcount - 1]; LOG_D("Dispatch compute") GLES.glDispatchCompute((total_indices + 63) / 64, 1, 1); CHECK_GL_ERROR_NO_INIT diff --git a/src/main/cpp/shaders/multidraw_compute.comp b/src/main/cpp/shaders/multidraw_compute.comp index f77a4aa..ddb9fb8 100644 --- a/src/main/cpp/shaders/multidraw_compute.comp +++ b/src/main/cpp/shaders/multidraw_compute.comp @@ -1,6 +1,6 @@ #version 310 es -layout(local_size_x = 256) in; +layout(local_size_x = 64) in; struct DrawCommand { uint count; @@ -16,8 +16,8 @@ layout(std430, binding = 2) readonly buffer Prefix { uint prefixSums[]; }; layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; }; void main() { - uint globalIdx = gl_GlobalInvocationID.x; - if (globalIdx >= prefixSums[prefixSums.length() - 1]) + uint outIdx = gl_GlobalInvocationID.x; + if (outIdx >= prefixSums[prefixSums.length() - 1]) return; // bisect to find out draw call # @@ -25,7 +25,7 @@ void main() { int high = draws.length() - 1; while(low < high) { int mid = (low + high + 1) / 2; - if (prefixSums[mid] <= globalIdx) { + if (prefixSums[mid] <= outIdx) { low = mid; } else { high = mid - 1; @@ -34,9 +34,9 @@ void main() { // figure out which index to take DrawCommand cmd = draws[low]; - uint localIdx = globalIdx - prefixSums[low]; + uint localIdx = outIdx - prefixSums[low]; uint srcIndex = cmd.firstIndex + localIdx; // Write out - out_indices[globalIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); + out_indices[outIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); }