[Fix] (multidraw): allocate output buffer

This commit is contained in:
Swung0x48 2025-04-21 10:34:35 +08:00
parent cd47992b41
commit 33f4a2e6c5
2 changed files with 25 additions and 13 deletions

View File

@ -263,17 +263,16 @@ layout(std430, binding = 2) readonly buffer Prefix { uint prefixSums[]; };
layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; }; layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; };
void main() { void main() {
uint globalIdx = gl_GlobalInvocationID.x; uint outIdx = gl_GlobalInvocationID.x;
if (globalIdx >= prefixSums[prefixSums.length() - 1]) if (outIdx >= prefixSums[prefixSums.length() - 1])
return; return;
out_indices[globalIdx] = globalIdx;
// bisect to find out draw call # // bisect to find out draw call #
int low = 0; int low = 0;
int high = draws.length() - 1; int high = draws.length() - 1;
while(low < high) { while(low < high) {
int mid = (low + high + 1) / 2; int mid = (low + high + 1) / 2;
if (prefixSums[mid] <= globalIdx) { if (prefixSums[mid] <= outIdx) {
low = mid; low = mid;
} else { } else {
high = mid - 1; high = mid - 1;
@ -282,11 +281,11 @@ void main() {
// figure out which index to take // figure out which index to take
DrawCommand cmd = draws[low]; DrawCommand cmd = draws[low];
uint localIdx = globalIdx - prefixSums[low]; uint localIdx = outIdx - prefixSums[low];
uint srcIndex = cmd.firstIndex + localIdx; uint srcIndex = cmd.firstIndex + localIdx;
// Write out // Write out
out_indices[globalIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); out_indices[outIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex);
} }
)"; )";
@ -382,6 +381,15 @@ GLAPI GLAPIENTRY void mg_glMultiDrawElementsBaseVertex_compute(
GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
CHECK_GL_ERROR_NO_INIT CHECK_GL_ERROR_NO_INIT
// Allocate output buffer
auto total_indices = g_prefix_sum[primcount - 1];
GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, g_outputibo);
CHECK_GL_ERROR_NO_INIT
GLES.glBufferData(GL_SHADER_STORAGE_BUFFER, sizeof(GLuint) * total_indices, nullptr, GL_DYNAMIC_DRAW);
CHECK_GL_ERROR_NO_INIT
GLES.glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
CHECK_GL_ERROR_NO_INIT
GLint ibo = 0; GLint ibo = 0;
GLES.glGetIntegerv(GL_ELEMENT_ARRAY_BUFFER_BINDING, &ibo); GLES.glGetIntegerv(GL_ELEMENT_ARRAY_BUFFER_BINDING, &ibo);
CHECK_GL_ERROR_NO_INIT CHECK_GL_ERROR_NO_INIT
@ -396,11 +404,15 @@ GLAPI GLAPIENTRY void mg_glMultiDrawElementsBaseVertex_compute(
GLES.glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 3, g_outputibo); GLES.glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 3, g_outputibo);
CHECK_GL_ERROR_NO_INIT CHECK_GL_ERROR_NO_INIT
// Save states
GLint prev_program = 0;
GLES.glGetIntegerv(GL_CURRENT_PROGRAM, &prev_program);
CHECK_GL_ERROR_NO_INIT
// Dispatch compute // Dispatch compute
LOG_D("Using compute program = %d", g_compute_program) LOG_D("Using compute program = %d", g_compute_program)
GLES.glUseProgram(g_compute_program); GLES.glUseProgram(g_compute_program);
CHECK_GL_ERROR_NO_INIT CHECK_GL_ERROR_NO_INIT
GLuint total_indices = g_prefix_sum[primcount - 1];
LOG_D("Dispatch compute") LOG_D("Dispatch compute")
GLES.glDispatchCompute((total_indices + 63) / 64, 1, 1); GLES.glDispatchCompute((total_indices + 63) / 64, 1, 1);
CHECK_GL_ERROR_NO_INIT CHECK_GL_ERROR_NO_INIT

View File

@ -1,6 +1,6 @@
#version 310 es #version 310 es
layout(local_size_x = 256) in; layout(local_size_x = 64) in;
struct DrawCommand { struct DrawCommand {
uint count; uint count;
@ -16,8 +16,8 @@ layout(std430, binding = 2) readonly buffer Prefix { uint prefixSums[]; };
layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; }; layout(std430, binding = 3) writeonly buffer Output { uint out_indices[]; };
void main() { void main() {
uint globalIdx = gl_GlobalInvocationID.x; uint outIdx = gl_GlobalInvocationID.x;
if (globalIdx >= prefixSums[prefixSums.length() - 1]) if (outIdx >= prefixSums[prefixSums.length() - 1])
return; return;
// bisect to find out draw call # // bisect to find out draw call #
@ -25,7 +25,7 @@ void main() {
int high = draws.length() - 1; int high = draws.length() - 1;
while(low < high) { while(low < high) {
int mid = (low + high + 1) / 2; int mid = (low + high + 1) / 2;
if (prefixSums[mid] <= globalIdx) { if (prefixSums[mid] <= outIdx) {
low = mid; low = mid;
} else { } else {
high = mid - 1; high = mid - 1;
@ -34,9 +34,9 @@ void main() {
// figure out which index to take // figure out which index to take
DrawCommand cmd = draws[low]; DrawCommand cmd = draws[low];
uint localIdx = globalIdx - prefixSums[low]; uint localIdx = outIdx - prefixSums[low];
uint srcIndex = cmd.firstIndex + localIdx; uint srcIndex = cmd.firstIndex + localIdx;
// Write out // Write out
out_indices[globalIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex); out_indices[outIdx] = uint(int(in_indices[srcIndex]) + cmd.baseVertex);
} }