OSDN Git Service

Added global (de)initialization functions.
authorLoRd_MuldeR <mulder2@gmx.de>
Thu, 22 Oct 2020 16:01:59 +0000 (18:01 +0200)
committerLoRd_MuldeR <mulder2@gmx.de>
Sat, 20 Mar 2021 20:18:38 +0000 (21:18 +0100)
Makefile
frontend/src/main.c
libslunkcrypt/include/slunkcrypt.h
libslunkcrypt/libSlunkCrypt.vcxproj
libslunkcrypt/libSlunkCrypt.vcxproj.filters
libslunkcrypt/src/junk.c [moved from libslunkcrypt/src/internal.c with 64% similarity]

index 91ae927..a42ba83 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ ifeq ($(STATIC),1)
   LDFLGS += -static
 endif
 
-ifneq ($(filter %-w64-mingw32,$(MACHINE)),)
+ifneq ($(filter %-w64-mingw32 %w64-windows-gnu,$(MACHINE)),)
   LDFLGS += -mconsole -municode
 endif
 
index e61c46d..f0cace7 100644 (file)
 
 #define BUFFER_SIZE 4096U
 
-#define OP_MODE_HELP 0
-#define OP_MODE_VERS 1
-#define OP_MODE_ENCR 2
-#define OP_MODE_DECR 3
-#define OP_MODE_TEST 4
+#define SLUNK_MODE_HELP 0
+#define SLUNK_MODE_VERS 1
+#define SLUNK_MODE_ENCR 2
+#define SLUNK_MODE_DECR 3
+#define SLUNK_MODE_TEST 4
 
 static const CHR *const ENVV_PASSWD_NAME = T("SLUNK_PASSPHRASE");
 
@@ -33,23 +33,23 @@ static int parse_mode(const CHR* const command)
 {
        if ((!STRICMP(command, T("-h"))) || (!STRICMP(command, T("/?"))) || (!STRICMP(command, T("--help"))))
        {
-               return OP_MODE_HELP;
+               return SLUNK_MODE_HELP;
        }
        else if ((!STRICMP(command, T("-v"))) || (!STRICMP(command, T("--version"))))
        {
-               return OP_MODE_VERS;
+               return SLUNK_MODE_VERS;
        }
        else if ((!STRICMP(command, T("-e"))) || (!STRICMP(command, T("--encrypt"))))
        {
-               return OP_MODE_ENCR;
+               return SLUNK_MODE_ENCR;
        }
        else if ((!STRICMP(command, T("-d"))) || (!STRICMP(command, T("--decrypt"))))
        {
-               return OP_MODE_DECR;
+               return SLUNK_MODE_DECR;
        }
        else if ((!STRICMP(command, T("-t"))) || (!STRICMP(command, T("--self-test"))))
        {
-               return OP_MODE_TEST;
+               return SLUNK_MODE_TEST;
        }
        else
        {
@@ -573,10 +573,13 @@ int MAIN(const int argc, CHR *const argv[])
        init_terminal();
        setup_signal_handler(SIGINT, sigint_handler);
        int result = EXIT_FAILURE;
+       char *passphrase_buffer = NULL;
 
        FPRINTF(stderr, T("SlunkCrypt Utility (%") T(PRIstr) T("-%") T(PRIstr) T("), by LoRd_MuldeR <MuldeR2@GMX.de>\n"), OS_TYPE, CPU_ARCH);
        FPRINTF(stderr, T("Using libSlunkCrypt v%u.%u.%u [%") T(PRIstr) T("]\n\n"), SLUNKCRYPT_VERSION_MAJOR, SLUNKCRYPT_VERSION_MINOR, SLUNKCRYPT_VERSION_PATCH, SLUNKCRYPT_BUILD);
 
+       slunkcrypt_startup();
+
        /* ----------------------------------------------------- */
        /* Parse arguments                                       */
        /* ----------------------------------------------------- */
@@ -584,24 +587,26 @@ int MAIN(const int argc, CHR *const argv[])
        if (argc < 2)
        {
                FPRINTF(stderr, T("Error: Nothing to do. Please type '%") T(PRISTR) T(" --help' for details!\n\n"), get_file_name(argv[0U]));
-               return EXIT_FAILURE;
+               goto clean_up;
        }
 
        const int mode = parse_mode(argv[1U]);
        switch (mode)
        {
-       case OP_MODE_HELP:
+       case SLUNK_MODE_HELP:
                print_manpage(get_file_name(argv[0U]));
-       case OP_MODE_VERS:
-               return EXIT_SUCCESS;
-       case OP_MODE_TEST:
-               return run_self_test();
+       case SLUNK_MODE_VERS:
+               result = EXIT_SUCCESS;
+               goto clean_up;
+       case SLUNK_MODE_TEST:
+               result = run_self_test();
+               goto clean_up;
        }
 
        if (argc < 4)
        {
                FPRINTF(stderr, T("Error: Required argument is missing. Please type '%") T(PRISTR) T(" --help' for details!\n\n"), get_file_name(argv[0U]));
-               return EXIT_FAILURE;
+               goto clean_up;
        }
 
        const CHR* const passphrase = (argc > 4) ? argv[2U] : GETENV(ENVV_PASSWD_NAME);
@@ -610,24 +615,24 @@ int MAIN(const int argc, CHR *const argv[])
        if ((!passphrase) || (!passphrase[0U]) || (((passphrase[0U] == T('@')) || (passphrase[0U] == T(':'))) && (!passphrase[1U])))
        {
                FPUTS(T("Error: The specified passphrase must not be empty!\n\n"), stderr);
-               return EXIT_FAILURE;
+               goto clean_up;
        }
 
        if ((!input_file[0U]) || (!output_file[0U]))
        {
                FPUTS(T("Error: The input file and/or output file must not be empty!\n\n"), stderr);
-               return EXIT_FAILURE;
+               goto clean_up;
        }
 
        /* ----------------------------------------------------- */
        /* Initialize passphrase                                 */
        /* ----------------------------------------------------- */
 
-       char *const passphrase_buffer = (passphrase[0U] == T('@')) ? read_passphrase(passphrase + 1U) : CHR_to_utf8((passphrase[0U] == T(':')) ? (passphrase + 1U) : passphrase);
+       passphrase_buffer = (passphrase[0U] == T('@')) ? read_passphrase(passphrase + 1U) : CHR_to_utf8((passphrase[0U] == T(':')) ? (passphrase + 1U) : passphrase);
        if (!passphrase_buffer)
        {
                FPUTS(T("Error: Failed to read the passphrase!\n\n"), stderr);
-               return EXIT_FAILURE;
+               goto clean_up;
        }
 
        slunkcrypt_bzero((CHR*)passphrase, STRLEN(passphrase) * sizeof(CHR));
@@ -636,12 +641,12 @@ int MAIN(const int argc, CHR *const argv[])
        if (passphrase_len < SLUNKCRYPT_PWDLEN_MIN)
        {
                FPRINTF(stderr, T("Error: Passphrase must be at least %") T(PRIu64) T(" characters in length!\n\n"), (uint64_t)SLUNKCRYPT_PWDLEN_MIN);
-               goto exiting;
+               goto clean_up;
        }
        else if (passphrase_len > SLUNKCRYPT_PWDLEN_MAX)
        {
                FPRINTF(stderr, T("Error: Passphrase must be at most %") T(PRIu64) T(" characters in length!\n\n"), (uint64_t)SLUNKCRYPT_PWDLEN_MAX);
-               goto exiting;
+               goto clean_up;
        }
 
        if (passphrase_len < 12U)
@@ -661,10 +666,10 @@ int MAIN(const int argc, CHR *const argv[])
 
        switch (mode)
        {
-       case OP_MODE_ENCR:
+       case SLUNK_MODE_ENCR:
                result = encrypt(passphrase_buffer, input_file, output_file);
                break;
-       case OP_MODE_DECR:
+       case SLUNK_MODE_DECR:
                result = decrypt(passphrase_buffer, input_file, output_file);
                break;
        default:
@@ -683,7 +688,7 @@ int MAIN(const int argc, CHR *const argv[])
        /* Final clean-up                                        */
        /* ----------------------------------------------------- */
 
-exiting:
+clean_up:
        
        if (passphrase_buffer)
        {
@@ -691,6 +696,7 @@ exiting:
                free(passphrase_buffer);
        }
 
+       slunkcrypt_cleanup();
        return result;
 }
 
index 93b0235..12a506a 100644 (file)
@@ -46,6 +46,12 @@ static const size_t SLUNKCRYPT_PWDLEN_MIN =   5U;
 static const size_t SLUNKCRYPT_PWDLEN_MAX = 512U;
 
 /*
+ * Global (de)initialization routines
+ */
+void slunkcrypt_startup(void);
+void slunkcrypt_cleanup(void);
+
+/*
  * Seed generator
  */
 int slunkcrypt_generate_seed(uint64_t* const seed);
index ae6970b..8a75d27 100644 (file)
@@ -19,7 +19,7 @@
     </ProjectConfiguration>
   </ItemGroup>
   <ItemGroup>
-    <ClCompile Include="src\internal.c" />
+    <ClCompile Include="src\junk.c" />
     <ClCompile Include="src\slunkcrypt.c" />
   </ItemGroup>
   <ItemGroup>
index c1afe5c..025ea26 100644 (file)
     </Filter>
   </ItemGroup>
   <ItemGroup>
-    <ClCompile Include="src\internal.c">
+    <ClCompile Include="src\slunkcrypt.c">
       <Filter>Source Files</Filter>
     </ClCompile>
-    <ClCompile Include="src\slunkcrypt.c">
+    <ClCompile Include="src\junk.c">
       <Filter>Source Files</Filter>
     </ClCompile>
   </ItemGroup>
similarity index 64%
rename from libslunkcrypt/src/internal.c
rename to libslunkcrypt/src/junk.c
index 1f64852..37c8703 100644 (file)
 #      endif
 #      if HAVE_GETRANDOM
 #              include <sys/random.h>
-#      else
-#              include <pthread.h>
 #      endif
 #endif
 
 // ==========================================================================
-// Initialization
+// (De)Initialization
 // ==========================================================================
 
 #if defined(_WIN32)
 typedef BOOLEAN(WINAPI *genrandom_t)(void*, ULONG);
-static genrandom_t win32_init_random(void)
+static HMODULE s_advapi32 = NULL;
+static genrandom_t s_genrandom = NULL;
+#elif !HAVE_GETRANDOM
+static const char *const DEV_RANDOM[] = { "/dev/urandom", "/dev/arandom", "/dev/random", NULL };
+static int s_random_fd = -1;
+#endif
+
+void slunkcrypt_startup(void)
 {
-       static volatile LONG s_random_init = 0L;
-       static HMODULE s_advapi32 = NULL;
-       static genrandom_t s_genrandom = NULL;
-       LONG state;
-       while ((state = InterlockedCompareExchange(&s_random_init, -1L, 0L)) != 0L)
+#if defined(_WIN32)
+       if (s_advapi32 || (s_advapi32 = LoadLibraryW(L"advapi32.dll")))
        {
-               if (state > 0L)
-               {
-                       return s_genrandom;
-               }
-               Sleep(0U);
+               s_genrandom = (genrandom_t)GetProcAddress(s_advapi32, "SystemFunction036");
        }
-       if (s_advapi32 || (s_advapi32 = LoadLibraryW(L"advapi32.dll")))
+#elif !HAVE_GETRANDOM
+       for (size_t i = 0U; (s_random_fd < 0) && DEV_RANDOM[i]; ++i)
        {
-               if ((s_genrandom = (genrandom_t)GetProcAddress(s_advapi32, "SystemFunction036")))
-               {
-                       InterlockedExchange(&s_random_init, 1L);
-                       return s_genrandom;
-               }
+               s_random_fd = open(DEV_RANDOM[i], O_RDONLY);
        }
-       InterlockedExchange(&s_random_init, 0L);
-       return NULL;
+#endif
 }
-#elif !HAVE_GETRANDOM
-static int unix_init_random(void)
+
+void slunkcrypt_cleanup(void)
 {
-       static pthread_mutex_t s_mutex = PTHREAD_MUTEX_INITIALIZER;
-       static int s_random_fd = -1;
-       static const char *const DEV_RANDOM[] = { "/dev/urandom", "/dev/arandom", "/dev/random", NULL };
-       if (pthread_mutex_lock(&s_mutex) != 0)
+#if defined(_WIN32)
+       s_genrandom = NULL;
+       if (s_advapi32)
        {
-               return -1;
+               FreeLibrary(s_advapi32);
+               s_advapi32 = NULL;
        }
-       if (s_random_fd < 0)
+#elif !HAVE_GETRANDOM
+       if (s_random_fd >= 0)
        {
-               for (size_t i = 0U; DEV_RANDOM[i]; ++i)
-               {
-                       if ((s_random_fd = open(DEV_RANDOM[i], O_RDONLY)) >= 0)
-                       {
-                               break;
-                       }
-               }
+               close(s_random_fd);
+               s_random_fd = -1;
        }
-       pthread_mutex_unlock(&s_mutex);
-       return s_random_fd;
-}
 #endif
+}
 
 // ==========================================================================
-// Public functions
+// Auxiliary functions
 // ==========================================================================
 
 int slunkcrypt_random_bytes(uint8_t* const buffer, const size_t length)
@@ -104,10 +92,9 @@ int slunkcrypt_random_bytes(uint8_t* const buffer, const size_t length)
 #if defined(_WIN32)
        if ((length <= ((size_t)ULONG_MAX)))
        {
-               const genrandom_t genrandom = win32_init_random();
-               if (genrandom)
+               if (s_genrandom)
                {
-                       return genrandom(buffer, (ULONG)length) ? 0 : (-1);
+                       return s_genrandom(buffer, (ULONG)length) ? 0 : (-1);
                }
        }
        return -1;
@@ -118,10 +105,9 @@ int slunkcrypt_random_bytes(uint8_t* const buffer, const size_t length)
        }
        return -1;
 #else
-       const int fd = unix_init_random();
-       if (fd >= 0)
+       if (s_random_fd >= 0)
        {
-               if (read(fd, buffer, length) >= length)
+               if (read(s_random_fd, buffer, length) >= length)
                {
                        return 0;
                }