OSDN Git Service

Some code refactoring in the thread pool module.
authorLoRd_MuldeR <mulder2@gmx.de>
Thu, 24 Mar 2022 20:25:32 +0000 (21:25 +0100)
committerLoRd_MuldeR <mulder2@gmx.de>
Thu, 24 Mar 2022 20:59:43 +0000 (21:59 +0100)
frontend/src/crypt.c
libslunkcrypt/src/slunkcrypt.c
libslunkcrypt/src/thread.c
libslunkcrypt/src/thread.h

index b47bec6..3f10ae3 100644 (file)
@@ -28,7 +28,7 @@
 
 static const uint64_t MAGIC_NUMBER = 0x243F6A8885A308D3ull;
 
-#define BUFFER_SIZE 32768U
+#define BUFFER_SIZE 65536U
 
 // ==========================================================================
 // Auxiliary functions
@@ -70,6 +70,13 @@ int encrypt(const char *const passphrase, const CHR *const input_path, const CHR
        FILE *file_in = NULL, *file_out = NULL;
        int result = EXIT_FAILURE, status;
 
+       uint8_t *buffer = malloc(BUFFER_SIZE * sizeof(uint8_t));
+       if (!buffer)
+       {
+               FPUTS(T("Error: Failed to allocate the I/O buffer!\n\n"), stderr);
+               goto clean_up;
+       }
+
        if (open_files(&file_in, &file_out, input_path, output_path) != EXIT_SUCCESS)
        {
                goto clean_up;
@@ -111,10 +118,7 @@ int encrypt(const char *const passphrase, const CHR *const input_path, const CHR
                goto clean_up;
        }
 
-       unsigned refresh_cycles = 0U;
-       uint64_t bytes_read = 0U, clk_update = clock_read();
-       uint8_t buffer[BUFFER_SIZE];
-
+       uint64_t bytes_read = 0U, clk_now, clk_update = clk_now = clock_read();
        const uint64_t update_interval = (uint64_t)(clock_freq() * 1.0625);
 
        blake2s_t blake2s_state;
@@ -147,15 +151,11 @@ int encrypt(const char *const passphrase, const CHR *const input_path, const CHR
                {
                        break; /*EOF*/
                }
-               if (!(++refresh_cycles & 0x3))
+               if (((clk_now = clock_read()) < clk_update) || (clk_now - clk_update > update_interval))
                {
-                       const uint64_t clk_now = clock_read();
-                       if ((clk_now < clk_update) || (clk_now - clk_update > update_interval))
-                       {
-                               FPRINTF(stderr, T("\b\b\b\b\b\b\b%5.1f%% "), (bytes_read / ((double)file_size)) * 100.0);
-                               fflush(stderr);
-                               clk_update = clk_now;
-                       }
+                       FPRINTF(stderr, T("\b\b\b\b\b\b\b%5.1f%% "), (bytes_read / ((double)file_size)) * 100.0);
+                       fflush(stderr);
+                       clk_update = clk_now;
                }
        }
 
@@ -238,7 +238,12 @@ clean_up:
                fclose(file_in);
        }
 
-       slunkcrypt_bzero(buffer, BUFFER_SIZE);
+       if (buffer)
+       {
+               slunkcrypt_bzero(buffer, BUFFER_SIZE * sizeof(uint8_t));
+               free(buffer);
+       }
+
        slunkcrypt_bzero(checksum_buffer, sizeof(uint64_t));
        slunkcrypt_bzero(&blake2s_state, sizeof(blake2s_t));
        slunkcrypt_bzero(&nonce, sizeof(uint64_t));
@@ -257,6 +262,13 @@ int decrypt(const char *const passphrase, const CHR *const input_path, const CHR
        FILE *file_in = NULL, *file_out = NULL;
        int result = EXIT_FAILURE, status;
 
+       uint8_t *buffer = malloc(BUFFER_SIZE * sizeof(uint8_t));
+       if (!buffer)
+       {
+               FPUTS(T("Error: Failed to allocate the I/O buffer!\n\n"), stderr);
+               goto clean_up;
+       }
+
        if (open_files(&file_in, &file_out, input_path, output_path) != EXIT_SUCCESS)
        {
                goto clean_up;
@@ -296,10 +308,7 @@ int decrypt(const char *const passphrase, const CHR *const input_path, const CHR
                goto clean_up;
        }
 
-       unsigned refresh_cycles = 0U;
-       uint64_t bytes_read = sizeof(uint64_t), clk_update = clock_read();
-       uint8_t buffer[BUFFER_SIZE];
-
+       uint64_t bytes_read = sizeof(uint64_t), clk_now, clk_update = clk_now = clock_read();
        const uint64_t update_interval = (uint64_t)(clock_freq() * 1.0625);
        const uint64_t read_limit = round_down(file_size, sizeof(uint64_t)) - (2U * sizeof(uint64_t));
 
@@ -333,15 +342,11 @@ int decrypt(const char *const passphrase, const CHR *const input_path, const CHR
                {
                        break; /*EOF*/
                }
-               if (!(++refresh_cycles & 0x3))
+               if (((clk_now = clock_read()) < clk_update) || (clk_now - clk_update > update_interval))
                {
-                       const uint64_t clk_now = clock_read();
-                       if ((clk_now < clk_update) || (clk_now - clk_update > update_interval))
-                       {
-                               FPRINTF(stderr, T("\b\b\b\b\b\b\b%5.1f%% "), (bytes_read / ((double)read_limit)) * 100.0);
-                               fflush(stderr);
-                               clk_update = clk_now;
-                       }
+                       FPRINTF(stderr, T("\b\b\b\b\b\b\b%5.1f%% "), (bytes_read / ((double)read_limit)) * 100.0);
+                       fflush(stderr);
+                       clk_update = clk_now;
                }
        }
 
@@ -436,7 +441,12 @@ clean_up:
                fclose(file_in);
        }
 
-       slunkcrypt_bzero(buffer, BUFFER_SIZE);
+       if (buffer)
+       {
+               slunkcrypt_bzero(buffer, BUFFER_SIZE * sizeof(uint8_t));
+               free(buffer);
+       }
+
        slunkcrypt_bzero(checksum_buffer, sizeof(uint64_t));
        slunkcrypt_bzero(&blake2s_state, sizeof(blake2s_t));
        slunkcrypt_bzero(&nonce, sizeof(uint64_t));
index d0785ad..635ffa5 100644 (file)
@@ -22,7 +22,7 @@ const char *const SLUNKCRYPT_BUILD = __DATE__ ", " __TIME__;
 
 /* Utilities */
 #define BOOLIFY(X) (!!(X))
-#define THREAD_COUNT(X) (((X)->thread_pool != THRDPL_NULL) ? slunkcrypt_thrdpl_count((X)->thread_pool) : 1U)
+#define THREAD_COUNT(X) (((X)->thread_pool) ? slunkcrypt_thrdpl_count((X)->thread_pool) : 1U)
 
 // ==========================================================================
 // Data structures
@@ -53,7 +53,7 @@ crypt_data_t;
 
 typedef struct
 {
-       thrdpl_t thread_pool;
+       thrdpl_t *thread_pool;
        crypt_data_t data;
 }
 crypt_state_t;
@@ -395,7 +395,7 @@ void slunkcrypt_free(const slunkcrypt_t context)
        crypt_state_t *const state = (crypt_state_t*) context;
        if (state)
        {
-               if (state->thread_pool != THRDPL_NULL)
+               if (state->thread_pool)
                {
                        slunkcrypt_thrdpl_destroy(state->thread_pool);
                }
index daf2332..0351101 100644 (file)
@@ -5,11 +5,11 @@
 
 /* Internal */
 #include "thread.h"
-#include "slunkcrypt.h"
 #include "compiler.h"
 
 /* CRT */
 #include <stdlib.h>
+#include <string.h>
 
 /* PThread */
 #if defined(_MSC_VER) && !defined(_DLL)
@@ -23,9 +23,9 @@
 #endif
 
 /* States */
-#define THRD_STATE_IDLE 0
-#define THRD_STATE_WORK 1
-#define THRD_STATE_EXIT 2
+#define TSTATE_IDLE 0U
+#define TSTATE_WORK 1U
+#define TSTATE_EXIT 2U
 
 // ==========================================================================
 // Data types
@@ -42,21 +42,27 @@ thrdpl_task_t;
 
 typedef struct
 {
-       const size_t *count;
-       int state;
+       size_t thread_count, pending;
        pthread_mutex_t mutex;
-       pthread_cond_t cond;
+       pthread_cond_t cond_pending;
+}
+thrdpl_shared_t;
+
+typedef struct
+{
+       thrdpl_shared_t *shared;
+       size_t state;
+       pthread_cond_t cond_state;
        pthread_t thread;
        thrdpl_task_t task;
 }
 thrdpl_thread_t;
 
-typedef struct
+struct thrdpl_data_t
 {
-       size_t thread_count;
+       thrdpl_shared_t shared;
        thrdpl_thread_t thread_data[MAX_THREADS];
-}
-thrdpl_data_t;
+};
 
 // ==========================================================================
 // Utilities
@@ -85,6 +91,15 @@ while(0)
 } \
 while(0)
 
+#define PTHRD_COND_SIGNAL(X) do \
+{ \
+       if (pthread_cond_signal((X)) != 0) \
+       { \
+               abort(); \
+       } \
+} \
+while(0)
+
 #define PTHRD_COND_BROADCAST(X) do \
 { \
        if (pthread_cond_broadcast((X)) != 0) \
@@ -105,13 +120,10 @@ while(0)
 
 #define CHECK_IF_CANCELLED() do \
 { \
-       if (data->state == THRD_STATE_EXIT) \
+       if (data->state == TSTATE_EXIT) \
        { \
-               if (pthread_mutex_unlock(&data->mutex) != 0) \
-               { \
-                       abort(); \
-               } \
-               return NULL; /* cancelled */ \
+               PTHRD_MUTEX_UNLOCK(&shared->mutex); \
+               return NULL; \
        } \
 } \
 while(0)
@@ -123,29 +135,37 @@ while(0)
 static void *worker_thread_main(void *const arg)
 {
        thrdpl_thread_t *const data = (thrdpl_thread_t*) arg;
+       thrdpl_shared_t *const shared = (thrdpl_shared_t*) data->shared;
+       
        thrdpl_task_t *task;
 
        for (;;)
        {
-               PTHRD_MUTEX_LOCK(&data->mutex);
+               PTHRD_MUTEX_LOCK(&shared->mutex);
                CHECK_IF_CANCELLED();
 
-               while (data->state != THRD_STATE_WORK)
+               while (data->state != TSTATE_WORK)
                {
-                       PTHRD_COND_WAIT(&data->cond, &data->mutex);
+                       PTHRD_COND_WAIT(&data->cond_state, &shared->mutex);
                        CHECK_IF_CANCELLED();
                }
 
                task = &data->task;
-               PTHRD_MUTEX_UNLOCK(&data->mutex);
+               PTHRD_MUTEX_UNLOCK(&shared->mutex);
 
-               task->worker(*data->count, task->context, task->buffer, task->length);
+               task->worker(shared->thread_count, task->context, task->buffer, task->length);
 
-               PTHRD_MUTEX_LOCK(&data->mutex);
+               PTHRD_MUTEX_LOCK(&shared->mutex);
                CHECK_IF_CANCELLED();
-               data->state = THRD_STATE_IDLE;
-               PTHRD_COND_BROADCAST(&data->cond);
-               PTHRD_MUTEX_UNLOCK(&data->mutex);
+
+               data->state = TSTATE_IDLE;
+               if (!(--shared->pending))
+               {
+                       PTHRD_COND_BROADCAST(&shared->cond_pending);
+               }
+
+               PTHRD_MUTEX_UNLOCK(&shared->mutex);
+               PTHRD_COND_SIGNAL(&data->cond_state);
        }
 }
 
@@ -175,42 +195,34 @@ static size_t detect_cpu_count(void)
 // Manage threads
 // ==========================================================================
 
-static int create_worker_thread(thrdpl_thread_t *const thread_data, const size_t *const count)
+static int create_worker(thrdpl_shared_t *const shared, thrdpl_thread_t *const thread_data)
 {
-       thread_data->count = count;
-       thread_data->state = THRD_STATE_IDLE;
-
-       if (pthread_mutex_init(&thread_data->mutex, NULL) != 0)
-       {
-               return -1;
-       }
+       thread_data->state = TSTATE_IDLE;
+       thread_data->shared = shared;
 
-       if (pthread_cond_init(&thread_data->cond, NULL) != 0)
+       if (pthread_cond_init(&thread_data->cond_state, NULL) != 0)
        {
-               pthread_mutex_destroy(&thread_data->mutex);
                return -1;
        }
 
        if (pthread_create(&thread_data->thread, NULL, worker_thread_main, thread_data) != 0)
        {
-               pthread_cond_destroy(&thread_data->cond);
-               pthread_mutex_destroy(&thread_data->mutex);
+               pthread_cond_destroy(&thread_data->cond_state);
                return -1;
        }
 
        return 0;
 }
 
-static int destroy_worker_thread(thrdpl_thread_t *const thread_data)
+static int destroy_worker(thrdpl_thread_t *const thread_data)
 {
-       PTHRD_MUTEX_LOCK(&thread_data->mutex);
-       thread_data->state = THRD_STATE_EXIT;
-       PTHRD_COND_BROADCAST(&thread_data->cond);
-       PTHRD_MUTEX_UNLOCK(&thread_data->mutex);
+       PTHRD_MUTEX_LOCK(&thread_data->shared->mutex);
+       thread_data->state = TSTATE_EXIT;
+       PTHRD_MUTEX_UNLOCK(&thread_data->shared->mutex);
 
+       PTHRD_COND_BROADCAST(&thread_data->cond_state);
        pthread_join(thread_data->thread, NULL);
-       pthread_mutex_destroy(&thread_data->mutex);
-       pthread_cond_destroy(&thread_data->cond);
+       pthread_cond_destroy(&thread_data->cond_state);
 
        return 0;
 }
@@ -219,116 +231,114 @@ static int destroy_worker_thread(thrdpl_thread_t *const thread_data)
 // Thread pool API
 // ==========================================================================
 
-thrdpl_t slunkcrypt_thrdpl_create(const size_t count)
+thrdpl_t *slunkcrypt_thrdpl_create(const size_t count)
 {
        size_t i, j;
-       thrdpl_data_t *pool = NULL;
+       thrdpl_t *thrdpl = NULL;
 
        const size_t cpu_count = bound(1U, (count > 0U) ? count : detect_cpu_count(), MAX_THREADS);
        if (cpu_count < 2U)
        {
-               return THRDPL_NULL;
+               return NULL;
        }
 
-       if (!(pool = (thrdpl_data_t*)malloc(sizeof(thrdpl_data_t))))
+       if (!(thrdpl = (thrdpl_t*)malloc(sizeof(thrdpl_t))))
        {
-               return THRDPL_NULL;
+               return NULL;
        }
 
-       slunkcrypt_bzero(pool, sizeof(thrdpl_data_t));
-       pool->thread_count = cpu_count;
+       memset(thrdpl, 0, sizeof(thrdpl_t));
+       thrdpl->shared.thread_count = cpu_count;
+
+       if (pthread_mutex_init(&thrdpl->shared.mutex, NULL) != 0)
+       {
+               goto failure;
+       }
+       
+       if (pthread_cond_init(&thrdpl->shared.cond_pending, NULL) != 0)
+       {
+               pthread_mutex_destroy(&thrdpl->shared.mutex);
+               goto failure;
+       }
 
-       for (i = 0U; i < pool->thread_count; ++i)
+       for (i = 0U; i < cpu_count; ++i)
        {
-               if (create_worker_thread(&pool->thread_data[i], &pool->thread_count) != 0)
+               if (create_worker(&thrdpl->shared, &thrdpl->thread_data[i]) != 0)
                {
                        for (j = 0U; j < i; ++j)
                        {
-                               destroy_worker_thread(&pool->thread_data[j]);
+                               destroy_worker(&thrdpl->thread_data[j]);
                        }
+                       pthread_cond_destroy(&thrdpl->shared.cond_pending);
+                       pthread_mutex_destroy(&thrdpl->shared.mutex);
                        goto failure;
                }
        }
 
-       return (thrdpl_t)pool;
+       return thrdpl;
 
 failure:
-       free(pool);
-       return (thrdpl_t)NULL;
+       free(thrdpl);
+       return NULL;
 }
 
-size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl)
+size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl)
 {
-       thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
-       return pool->thread_count;
+       return thrdpl->shared.thread_count;
 }
 
-void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length)
+void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length)
 {
-       thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
-       thrdpl_thread_t *const thread = &pool->thread_data[index];
+       thrdpl_thread_t *const thread = &thrdpl->thread_data[index];
 
-       PTHRD_MUTEX_LOCK(&thread->mutex);
+       PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex);
 
-       while ((thread->state != THRD_STATE_IDLE) && (thread->state != THRD_STATE_EXIT))
+       while ((thread->state != TSTATE_IDLE) && (thread->state != TSTATE_EXIT))
        {
-               if (pthread_cond_wait(&thread->cond, &thread->mutex) != 0)
-               {
-                       abort();
-               }
+               PTHRD_COND_WAIT(&thread->cond_state, &thrdpl->shared.mutex);
        }
 
-       if (thread->state == THRD_STATE_EXIT)
+       if (thread->state == TSTATE_EXIT)
        {
                abort(); /*this is not supposed to happen!*/
        }
 
+       thread->state = TSTATE_WORK;
        thread->task.worker = worker;
        thread->task.context = context;
        thread->task.buffer = buffer;
        thread->task.length = length;
-       thread->state = THRD_STATE_WORK;
 
-       PTHRD_COND_BROADCAST(&thread->cond);
-       PTHRD_MUTEX_UNLOCK(&thread->mutex);
+       ++thrdpl->shared.pending;
+
+       PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex);
+       PTHRD_COND_SIGNAL(&thread->cond_state);
 }
 
-void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl)
+void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl)
 {
-       size_t i;
-       thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
+       PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex);
 
-       for (i = 0; i < pool->thread_count; ++i)
+       while (thrdpl->shared.pending)
        {
-               if (pthread_mutex_lock(&pool->thread_data[i].mutex) != 0)
-               {
-                       abort();
-               }
-               while ((pool->thread_data[i].state != THRD_STATE_IDLE) && (pool->thread_data[i].state != THRD_STATE_EXIT))
-               {
-                       if (pthread_cond_wait(&pool->thread_data[i].cond, &pool->thread_data[i].mutex) != 0)
-                       {
-                               abort();
-                       }
-               }
-               if (pthread_mutex_unlock(&pool->thread_data[i].mutex) != 0)
-               {
-                       abort();
-               }
+               PTHRD_COND_WAIT(&thrdpl->shared.cond_pending, &thrdpl->shared.mutex);
        }
+
+       PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex);
 }
 
-void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl)
+void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl)
 {
        size_t i;
-       thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
-       if (pool)
+
+       if (thrdpl)
        {
-               for (i = 0U; i < pool->thread_count; ++i)
+               for (i = 0U; i < thrdpl->shared.thread_count; ++i)
                {
-                       destroy_worker_thread(&pool->thread_data[i]);
+                       destroy_worker(&thrdpl->thread_data[i]);
                }
-               slunkcrypt_bzero(pool, sizeof(thrdpl_data_t));
-               free(pool);
+               pthread_cond_destroy(&thrdpl->shared.cond_pending);
+               pthread_mutex_destroy(&thrdpl->shared.mutex);
+               free(thrdpl);
        }
 }
index 9af2b4c..b21c3a5 100644 (file)
@@ -9,16 +9,15 @@
 #include <stdlib.h>
 #include <stdint.h>
 
-#define MAX_THREADS 16U
-#define THRDPL_NULL ((thrdpl_t)NULL)
+#define MAX_THREADS 32U
 
 typedef void (*thrdpl_worker_t)(const size_t thread_count, void *const context, uint8_t *const buffer, const size_t length);
-typedef uintptr_t thrdpl_t;
+typedef struct thrdpl_data_t thrdpl_t;
 
-thrdpl_t slunkcrypt_thrdpl_create(const size_t count);
-size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl);
-void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length);
-void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl);
-void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl);
+thrdpl_t *slunkcrypt_thrdpl_create(const size_t count);
+size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl);
+void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length);
+void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl);
+void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl);
 
 #endif