libzahl

big integer library
git clone git://git.suckless.org/libzahl
Log | Files | Refs | README | LICENSE

commit 3413d878a7b1a37ff362ddfa9141349e73af917a
parent dbf780a6c78f1396f988ef285be2de7c66b6d90b
Author: Mattias Andrée <maandree@kth.se>
Date:   Tue, 26 Apr 2016 22:40:43 +0200

Ensure that failure does not result in memory leak

Signed-off-by: Mattias Andrée <maandree@kth.se>

Diffstat:
Msrc/internals.h | 65++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
Msrc/zmul.c | 16++++++++--------
Msrc/zsetup.c | 14+++++++++++++-
Msrc/zsqr.c | 8++++----
Msrc/zunsetup.c | 2++
Mzahl-internals.h | 10++++++++++
Mzahl.h | 11+----------
7 files changed, 102 insertions(+), 24 deletions(-)

diff --git a/src/internals.h b/src/internals.h @@ -89,11 +89,13 @@ extern int libzahl_error; extern zahl_char_t **libzahl_pool[sizeof(size_t) * 8]; extern size_t libzahl_pool_n[sizeof(size_t) * 8]; extern size_t libzahl_pool_alloc[sizeof(size_t) * 8]; +extern struct zahl **libzahl_temp_stack; +extern struct zahl **libzahl_temp_stack_head; +extern struct zahl **libzahl_temp_stack_end; #define likely(expr) ZAHL_LIKELY(expr) #define unlikely(expr) ZAHL_UNLIKELY(expr) -#define libzahl_failure(error) (libzahl_error = (error), longjmp(libzahl_jmp_buf, 1)) #define SET_SIGNUM(a, signum) ZAHL_SET_SIGNUM(a, signum) #define SET(a, b) ZAHL_SET(a, b) #define ENSURE_SIZE(a, n) do { if ((a)->alloced < (n)) libzahl_realloc(a, (n)); } while (0) @@ -117,6 +119,16 @@ void libzahl_realloc(z_t a, size_t need); void zmul_impl(z_t a, z_t b, z_t c); void zsqr_impl(z_t a, z_t b); +static void +libzahl_failure(int error) +{ + libzahl_error = (error); + if (libzahl_temp_stack) + while (libzahl_temp_stack_head != libzahl_temp_stack) + zfree(*--libzahl_temp_stack_head); + longjmp(libzahl_jmp_buf, 1); +} + static inline void zmemcpy(zahl_char_t *restrict d, const zahl_char_t *restrict s, register size_t n) { @@ -272,3 +284,54 @@ zsplit_unsigned_fast_small_auto(z_t high, z_t low, z_t a, size_t n) if (unlikely(!low->chars[0])) low->sign = 0; } + +/* Calls to these functions must be called in stack-order + * For example, + * + * zinit_temp(a); + * zinit_temp(b); + * zfree_temp(b); + * zinit_temp(c); + * zfree_temp(c); + * zfree_temp(a); + * + * + * And not (swap the two last lines) + * + * zinit_temp(a); + * zinit_temp(b); + * zfree_temp(b); + * zinit_temp(c); + * zfree_temp(a); + * zfree_temp(c); + * + * { */ + +static inline void +zinit_temp(z_t a) +{ + zinit(a); + if (unlikely(libzahl_temp_stack_head == libzahl_temp_stack_end)) { + size_t n = (size_t)(libzahl_temp_stack_end - libzahl_temp_stack); + void* old = libzahl_temp_stack; + libzahl_temp_stack = realloc(old, 2 * n * sizeof(*libzahl_temp_stack)); + if (unlikely(!libzahl_temp_stack)) { + libzahl_temp_stack = old; + if (!errno) /* sigh... */ + errno = ENOMEM; + libzahl_failure(errno); + } + libzahl_temp_stack_head = libzahl_temp_stack + n; + libzahl_temp_stack_end = libzahl_temp_stack_head + n; + } + *libzahl_temp_stack_head++ = a; +} + +static inline void +zfree_temp(z_t a) +{ + zfree(a); + libzahl_temp_stack_head--; +} + +/* } */ diff --git a/src/zmul.c b/src/zmul.c @@ -48,10 +48,10 @@ zmul_impl(z_t a, z_t b, z_t c) m = MAX(m, m2); m2 = m >> 1; - zinit(b_high); - zinit(b_low); - zinit(c_high); - zinit(c_low); + zinit_temp(b_high); + zinit_temp(b_low); + zinit_temp(c_high); + zinit_temp(c_low); zsplit(b_high, b_low, b, m2); zsplit(c_high, c_low, c, m2); @@ -73,10 +73,10 @@ zmul_impl(z_t a, z_t b, z_t c) zadd_unsigned_assign(a, z2); - zfree(b_high); - zfree(b_low); - zfree(c_high); - zfree(c_low); + zfree_temp(c_low); + zfree_temp(c_high); + zfree_temp(b_low); + zfree_temp(b_high); } void diff --git a/src/zsetup.c b/src/zsetup.c @@ -15,6 +15,9 @@ int libzahl_error; zahl_char_t **libzahl_pool[sizeof(size_t) * 8]; size_t libzahl_pool_n[sizeof(size_t) * 8]; size_t libzahl_pool_alloc[sizeof(size_t) * 8]; +struct zahl **libzahl_temp_stack; +struct zahl **libzahl_temp_stack_head; +struct zahl **libzahl_temp_stack_end; void @@ -23,7 +26,7 @@ zsetup(jmp_buf env) size_t i; *libzahl_jmp_buf = *env; - if (!libzahl_set_up) { + if (likely(!libzahl_set_up)) { libzahl_set_up = 1; memset(libzahl_pool, 0, sizeof(libzahl_pool)); @@ -40,5 +43,14 @@ zsetup(jmp_buf env) #undef X for (i = BITS_PER_CHAR; i--;) zinit(libzahl_tmp_divmod_ds[i]); + + libzahl_temp_stack = malloc(256 * sizeof(*libzahl_temp_stack)); + if (unlikely(!libzahl_temp_stack)) { + if (!errno) /* sigh... */ + errno = ENOMEM; + libzahl_failure(errno); + } + libzahl_temp_stack_head = libzahl_temp_stack; + libzahl_temp_stack_end = libzahl_temp_stack + 256; } } diff --git a/src/zsqr.c b/src/zsqr.c @@ -50,8 +50,8 @@ zsqr_impl(z_t a, z_t b) zsqr_impl(z2, high); zlsh(a, z2, bits << 1); } else { - zinit(z0); - zinit(z1); + zinit_temp(z0); + zinit_temp(z1); zsqr_impl(z0, low); @@ -64,8 +64,8 @@ zsqr_impl(z_t a, z_t b) zadd_unsigned_assign(a, z1); zadd_unsigned_assign(a, z0); - zfree(z0); - zfree(z1); + zfree_temp(z1); + zfree_temp(z0); } } diff --git a/src/zunsetup.c b/src/zunsetup.c @@ -24,5 +24,7 @@ zunsetup(void) free(libzahl_pool[i][libzahl_pool_n[i]]); free(libzahl_pool[i]); } + + free(libzahl_temp_stack); } } diff --git a/zahl-internals.h b/zahl-internals.h @@ -62,3 +62,13 @@ typedef uint64_t zahl_char_t; + +struct zahl { + int sign; +#if INT_MAX != LONG_MAX + int padding__; +#endif + size_t used; + size_t alloced; + zahl_char_t *chars; +}; diff --git a/zahl.h b/zahl.h @@ -16,16 +16,7 @@ -/* This structure should be considered opaque. */ -typedef struct { - int sign; -#if INT_MAX != LONG_MAX - int padding__; -#endif - size_t used; - size_t alloced; - zahl_char_t *chars; -} z_t[1]; +typedef struct zahl z_t[1];