diff --git a/Makefile b/Makefile index 1454fcd..eafed43 100644 --- a/Makefile +++ b/Makefile @@ -215,6 +215,7 @@ TEST_PROG_COMMON_SRC := programs/test_util.c TEST_PROG_SRC := programs/benchmark.c \ programs/checksum.c \ programs/test_checksums.c \ + programs/test_custom_malloc.c \ programs/test_incomplete_codes.c \ programs/test_slow_decompression.c \ programs/test_trailing_bytes.c diff --git a/lib/utils.c b/lib/utils.c index 0cb81cc..2d128d2 100644 --- a/lib/utils.c +++ b/lib/utils.c @@ -29,16 +29,21 @@ #include "lib_common.h" +#include "libdeflate.h" + +static void *(*libdeflate_malloc_func)(size_t) = malloc; +static void (*libdeflate_free_func)(void *) = free; + void * libdeflate_malloc(size_t size) { - return malloc(size); + return (*libdeflate_malloc_func)(size); } void libdeflate_free(void *ptr) { - free(ptr); + (*libdeflate_free_func)(ptr); } void * @@ -59,3 +64,11 @@ libdeflate_aligned_free(void *ptr) if (ptr) libdeflate_free(((void **)ptr)[-1]); } + +LIBDEFLATEEXPORT void LIBDEFLATEAPI +libdeflate_set_memory_allocator(void *(*malloc_func)(size_t), + void (*free_func)(void *)) +{ + libdeflate_malloc_func = malloc_func; + libdeflate_free_func = free_func; +} diff --git a/libdeflate.h b/libdeflate.h index 27d8e87..94ddeff 100644 --- a/libdeflate.h +++ b/libdeflate.h @@ -339,6 +339,22 @@ libdeflate_adler32(uint32_t adler32, const void *buffer, size_t len); LIBDEFLATEEXPORT uint32_t LIBDEFLATEAPI libdeflate_crc32(uint32_t crc, const void *buffer, size_t len); +/* ========================================================================== */ +/* Custom memory allocator */ +/* ========================================================================== */ + +/* + * Install a custom memory allocator which libdeflate will use for all memory + * allocations. 'malloc_func' is a function that must behave like malloc(), and + * 'free_func' is a function that must behave like free(). + * + * There must not be any libdeflate_compressor or libdeflate_decompressor + * structures in existence when calling this function. + */ +LIBDEFLATEEXPORT void LIBDEFLATEAPI +libdeflate_set_memory_allocator(void *(*malloc_func)(size_t), + void (*free_func)(void *)); + #ifdef __cplusplus } #endif diff --git a/programs/test_custom_malloc.c b/programs/test_custom_malloc.c new file mode 100644 index 0000000..7e1eced --- /dev/null +++ b/programs/test_custom_malloc.c @@ -0,0 +1,85 @@ +/* + * test_custom_malloc.c + * + * Test libdeflate_set_memory_allocator(). + * Also test injecting allocation failures. + */ + +#include "test_util.h" + +static int malloc_count = 0; +static int free_count = 0; + +static void *do_malloc(size_t size) +{ + malloc_count++; + return malloc(size); +} + +static void *do_fail_malloc(size_t size) +{ + malloc_count++; + return NULL; +} + +static void do_free(void *ptr) +{ + free_count++; + free(ptr); +} + +int +tmain(int argc, tchar *argv[]) +{ + int level; + struct libdeflate_compressor *c; + struct libdeflate_decompressor *d; + + begin_program(argv); + + /* Test that the custom allocator is actually used when requested. */ + + libdeflate_set_memory_allocator(do_malloc, do_free); + ASSERT(malloc_count == 0); + ASSERT(free_count == 0); + + for (level = 1; level <= 12; level++) { + malloc_count = free_count = 0; + c = libdeflate_alloc_compressor(level); + ASSERT(c != NULL); + ASSERT(malloc_count == 1); + ASSERT(free_count == 0); + libdeflate_free_compressor(c); + ASSERT(malloc_count == 1); + ASSERT(free_count == 1); + } + + malloc_count = free_count = 0; + d = libdeflate_alloc_decompressor(); + ASSERT(d != NULL); + ASSERT(malloc_count == 1); + ASSERT(free_count == 0); + libdeflate_free_decompressor(d); + ASSERT(malloc_count == 1); + ASSERT(free_count == 1); + + /* As long as we're here, also test injecting allocation failures. */ + + libdeflate_set_memory_allocator(do_fail_malloc, do_free); + + for (level = 1; level <= 12; level++) { + malloc_count = free_count = 0; + c = libdeflate_alloc_compressor(level); + ASSERT(c == NULL); + ASSERT(malloc_count == 1); + ASSERT(free_count == 0); + } + + malloc_count = free_count = 0; + d = libdeflate_alloc_decompressor(); + ASSERT(d == NULL); + ASSERT(malloc_count == 1); + ASSERT(free_count == 0); + + return 0; +}