author     Marat Dukhan <maratek@gmail.com>   2015-08-22 17:46:29 -0400
committer  Marat Dukhan <maratek@gmail.com>   2015-08-22 17:46:29 -0400
commit     0a312191e15b30e8e50a89077edc92bc2ee5682d (patch)
tree       673d7b763a6c95ac4d29499b1d2f18da9778a441
parent     faf6ff30a6ac3fc9be846cf3a89942ebec78aeeb (diff)
download   pthreadpool-0a312191e15b30e8e50a89077edc92bc2ee5682d.tar.gz
Initial thread pool implementation
-rw-r--r--   .gitignore              17
-rwxr-xr-x   configure.py           204
-rw-r--r--   include/pthreadpool.h   73
-rw-r--r--   src/pthreadpool.c      296
-rw-r--r--   test/pthreadpool.cc    111
5 files changed, 701 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f42c5bc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,17 @@
+# Ninja files
+.ninja_deps
+.ninja_log
+build.ninja
+
+# Build objects and artifacts
+build/*
+artifacts/*
+
+# System files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
diff --git a/configure.py b/configure.py
new file mode 100755
index 0000000..72db0c4
--- /dev/null
+++ b/configure.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python
+
+import os
+import sys
+import glob
+import argparse
+import ninja_syntax
+
+
+class Configuration:
+
+ def __init__(self, options, ninja_build_file=os.path.join(os.path.dirname(os.path.abspath(__file__)), "build.ninja")):
+ self.output = open(ninja_build_file, "w")
+ self.writer = ninja_syntax.Writer(self.output)
+ self.source_dir = None
+ self.build_dir = None
+ self.artifact_dir = None
+ self.prefix_dir = options.prefix
+ self.include_dirs = []
+ self.object_ext = ".bc"
+
+ # Variables
+ self.writer.variable("nacl_sdk_dir", options.nacl_sdk)
+ self._set_pnacl_vars()
+ self.writer.variable("cflags", "-std=gnu11")
+ self.writer.variable("cxxflags", "-std=gnu++11")
+ self.writer.variable("optflags", "-O3")
+
+ # Rules
+ self.writer.rule("cc", "$pnacl_cc -o $out -c $in -MMD -MF $out.d $optflags $cflags $includes",
+ deps="gcc", depfile="$out.d",
+ description="CC[PNaCl] $descpath")
+ self.writer.rule("cxx", "$pnacl_cxx -o $out -c $in -MMD -MF $out.d $optflags $cxxflags $includes",
+ deps="gcc", depfile="$out.d",
+ description="CXX[PNaCl] $descpath")
+ self.writer.rule("ccld", "$pnacl_cc -o $out $in $libs $libdirs $ldflags",
+ description="CCLD[PNaCl] $descpath")
+ self.writer.rule("cxxld", "$pnacl_cxx -o $out $in $libs $libdirs $ldflags",
+ description="CXXLD[PNaCl] $descpath")
+ self.writer.rule("ar", "$pnacl_ar rcs $out $in",
+ description="AR[PNaCl] $descpath")
+ self.writer.rule("finalize", "$pnacl_finalize $finflags -o $out $in",
+ description="FINALIZE[PNaCl] $descpath")
+ self.writer.rule("translate", "$pnacl_translate -arch $arch -o $out $in",
+ description="TRANSLATE[PNaCl] $descpath")
+ self.writer.rule("run", "$pnacl_sel_ldr $in",
+ description="RUN[PNaCl] $descpath", pool="console")
+ self.writer.rule("install", "install -m $mode $in $out",
+ description="INSTALL $out")
+
+
+ def _set_pnacl_vars(self):
+ if sys.platform == "win32":
+ self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/win_pnacl")
+ self.writer.variable("pnacl_cc", "$pnacl_toolchain_dir/bin/pnacl-clang.bat")
+ self.writer.variable("pnacl_cxx", "$pnacl_toolchain_dir/bin/pnacl-clang++.bat")
+ self.writer.variable("pnacl_ar", "$pnacl_toolchain_dir/bin/pnacl-ar.bat")
+ self.writer.variable("pnacl_finalize", "$pnacl_toolchain_dir/bin/pnacl-finalize.bat")
+ self.writer.variable("pnacl_translate", "$pnacl_toolchain_dir/bin/pnacl-translate.bat")
+ elif sys.platform == "linux2" or sys.platform == "darwin":
+ if sys.platform == "linux2":
+ self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/linux_pnacl")
+ else:
+ self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/mac_pnacl")
+ self.writer.variable("pnacl_cc", "$pnacl_toolchain_dir/bin/pnacl-clang")
+ self.writer.variable("pnacl_cxx", "$pnacl_toolchain_dir/bin/pnacl-clang++")
+ self.writer.variable("pnacl_ar", "$pnacl_toolchain_dir/bin/pnacl-ar")
+ self.writer.variable("pnacl_finalize", "$pnacl_toolchain_dir/bin/pnacl-finalize")
+ self.writer.variable("pnacl_translate", "$pnacl_toolchain_dir/bin/pnacl-translate")
+ else:
+ raise OSError("Unsupported platform: " + sys.platform)
+ self.writer.variable("pnacl_sel_ldr", "$nacl_sdk_dir/tools/sel_ldr.py")
+
+
+ def _compile(self, rule, source_file, object_file):
+ if not os.path.isabs(source_file):
+ source_file = os.path.join(self.source_dir, source_file)
+ if object_file is None:
+ object_file = os.path.join(self.build_dir, os.path.relpath(source_file, self.source_dir)) + self.object_ext
+ variables = {
+ "descpath": os.path.relpath(source_file, self.source_dir)
+ }
+ if self.include_dirs:
+ variables["includes"] = " ".join(["-I" + i for i in self.include_dirs])
+ self.writer.build(object_file, rule, source_file, variables=variables)
+ return object_file
+
+
+ def cc(self, source_file, object_file=None):
+ return self._compile("cc", source_file, object_file)
+
+
+ def cxx(self, source_file, object_file=None):
+ return self._compile("cxx", source_file, object_file)
+
+
+ def _link(self, rule, object_files, binary_file, binary_dir, lib_dirs, libs):
+ if not os.path.isabs(binary_file):
+ binary_file = os.path.join(binary_dir, binary_file)
+ variables = {
+ "descpath": os.path.relpath(binary_file, binary_dir)
+ }
+ if lib_dirs:
+ variables["libdirs"] = " ".join(["-L" + l for l in lib_dirs])
+ if libs:
+ variables["libs"] = " ".join(["-l" + l for l in libs])
+ self.writer.build(binary_file, rule, object_files, variables=variables)
+ return binary_file
+
+
+ def ccld(self, object_files, binary_file, lib_dirs=[], libs=[]):
+ return self._link("ccld", object_files, binary_file, self.build_dir, lib_dirs, libs)
+
+
+ def cxxld(self, object_files, binary_file, lib_dirs=[], libs=[]):
+ return self._link("cxxld", object_files, binary_file, self.build_dir, lib_dirs, libs)
+
+
+ def ar(self, object_files, archive_file):
+ if not os.path.isabs(archive_file):
+ archive_file = os.path.join(self.artifact_dir, archive_file)
+ variables = {
+ "descpath": os.path.relpath(archive_file, self.artifact_dir)
+ }
+ self.writer.build(archive_file, "ar", object_files, variables=variables)
+ return archive_file
+
+ def finalize(self, binary_file, executable_file):
+ if not os.path.isabs(binary_file):
+ binary_file = os.path.join(self.build_dir, binary_file)
+ if not os.path.isabs(executable_file):
+ executable_file = os.path.join(self.artifact_dir, executable_file)
+ variables = {
+ "descpath": os.path.relpath(executable_file, self.artifact_dir)
+ }
+ self.writer.build(executable_file, "finalize", binary_file, variables=variables)
+ return executable_file
+
+ def translate(self, portable_file, native_file):
+ if not os.path.isabs(portable_file):
+ portable_file = os.path.join(self.artifact_dir, portable_file)
+ if not os.path.isabs(native_file):
+ native_file = os.path.join(self.artifact_dir, native_file)
+ variables = {
+ "descpath": os.path.relpath(portable_file, self.artifact_dir),
+ "arch": "x86_64"
+ }
+ self.writer.build(native_file, "translate", portable_file, variables=variables)
+ return native_file
+
+ def run(self, executable_file, target):
+ variables = {
+ "descpath": os.path.relpath(executable_file, self.artifact_dir)
+ }
+ self.writer.build(target, "run", executable_file, variables=variables)
+
+ def install(self, source_file, destination_file, mode=0o644):
+ if not os.path.isabs(destination_file):
+ destination_file = os.path.join(self.prefix_dir, destination_file)
+ variables = {
+ "mode": "0%03o" % mode
+ }
+ self.writer.build(destination_file, "install", source_file, variables=variables)
+ return destination_file
+
+
+parser = argparse.ArgumentParser(description="PThreadPool configuration script")
+parser.add_argument("--with-nacl-sdk", dest="nacl_sdk", default=os.getenv("NACL_SDK_ROOT"),
+ help="Native Client (Pepper) SDK to use")
+parser.add_argument("--prefix", dest="prefix", default="/usr/local")
+
+
+def main():
+ options = parser.parse_args()
+
+ config = Configuration(options)
+
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+
+ config.source_dir = os.path.join(root_dir, "src")
+ config.build_dir = os.path.join(root_dir, "build")
+ config.artifact_dir = os.path.join(root_dir, "artifacts")
+ config.include_dirs = [os.path.join("$nacl_sdk_dir", "include"), os.path.join(root_dir, "include"), os.path.join(root_dir, "src")]
+
+ pthreadpool_object = config.cc("pthreadpool.c")
+ pthreadpool_library = config.ar([pthreadpool_object], "libpthreadpool.a")
+
+ config.source_dir = os.path.join(root_dir, "test")
+ config.build_dir = os.path.join(root_dir, "build", "test")
+ pthreadpool_test_object = config.cxx("pthreadpool.cc")
+ pthreadpool_test_binary = config.cxxld([pthreadpool_object, pthreadpool_test_object], "pthreadpool.bc", libs=["gtest"], lib_dirs=[os.path.join("$nacl_sdk_dir", "lib", "pnacl", "Release")])
+ pthreadpool_test_binary = config.finalize(pthreadpool_test_binary, "pthreadpool.pexe")
+ pthreadpool_test_binary = config.translate(pthreadpool_test_binary, "pthreadpool.nexe")
+ config.run(pthreadpool_test_binary, "test")
+
+ config.writer.default([pthreadpool_library, pthreadpool_test_binary])
+
+ config.writer.build("install", "phony", [
+ config.install(os.path.join(root_dir, "include", "pthreadpool.h"), "include/pthreadpool.h"),
+	config.install(pthreadpool_library, "lib/libpthreadpool.a")])
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/include/pthreadpool.h b/include/pthreadpool.h
new file mode 100644
index 0000000..dc0847d
--- /dev/null
+++ b/include/pthreadpool.h
@@ -0,0 +1,73 @@
+#include <stddef.h>
+
+#ifndef PTHREADPOOL_H
+#define PTHREADPOOL_H
+
+typedef struct pthreadpool* pthreadpool_t;
+
+typedef void (*pthreadpool_function_1d_t)(void*, size_t);
+typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t);
+typedef void (*pthreadpool_function_3d_t)(void*, size_t, size_t, size_t);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Creates a thread pool with the specified number of threads.
+ *
+ * @param[in] threads_count The number of threads in the thread pool.
+ * A value of 0 has special interpretation: it creates a thread for each
+ * processor core available in the system.
+ *
+ * @returns A pointer to an opaque thread pool object.
+ * On error the function returns NULL and sets errno accordingly.
+ */
+pthreadpool_t pthreadpool_create(size_t threads_count);
+
+/**
+ * Queries the number of threads in a thread pool.
+ *
+ * @param[in] threadpool The thread pool to query.
+ *
+ * @returns The number of threads in the thread pool.
+ */
+size_t pthreadpool_get_threads_count(pthreadpool_t threadpool);
+
+
+/**
+ * Processes items in parallel using threads from a thread pool.
+ *
+ * When the call returns, all items have been processed and the thread pool is
+ * ready for a new task.
+ *
+ * @note If multiple threads call this function with the same thread pool, the
+ * calls are serialized.
+ *
+ * @param[in] threadpool The thread pool to use for parallelisation.
+ * @param[in] function The function to call for each item.
+ * @param[in] argument The first argument passed to the @a function.
+ * @param[in] items The number of items to process. The @a function
+ * will be called once for each item.
+ */
+void pthreadpool_compute_1d(
+ pthreadpool_t threadpool,
+ pthreadpool_function_1d_t function,
+ void* argument,
+ size_t items);
+
+/**
+ * Terminates threads in the thread pool and releases associated resources.
+ *
+ * @warning Accessing the thread pool after a call to this function constitutes
+ * undefined behaviour and may cause data corruption.
+ *
+ * @param[in,out] threadpool The thread pool to destroy.
+ */
+void pthreadpool_destroy(pthreadpool_t threadpool);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* PTHREADPOOL_H */
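
For reference, a minimal usage sketch of the API declared above (illustrative only, not part of this commit; the squaring callback and array size are made up):

#include <stddef.h>
#include <stdio.h>
#include <pthreadpool.h>

/* Callback invoked once per item; the context pointer is the argument
 * passed to pthreadpool_compute_1d. */
static void square(void* context, size_t i) {
	double* array = (double*) context;
	array[i] = (double) i * (double) i;
}

int main(void) {
	double array[1000];
	/* 0 threads = one worker thread per available processor core */
	pthreadpool_t threadpool = pthreadpool_create(0);
	pthreadpool_compute_1d(threadpool, square, array, 1000);
	pthreadpool_destroy(threadpool);
	printf("array[999] = %.1f\n", array[999]);
	return 0;
}
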
diff --git a/src/pthreadpool.c b/src/pthreadpool.c
new file mode 100644
index 0000000..f270204
--- /dev/null
+++ b/src/pthreadpool.c
@@ -0,0 +1,296 @@
+/* Standard C headers */
+#include <stdint.h>
+#include <stdbool.h>
+#include <malloc.h>
+#include <string.h>
+#include <assert.h>
+
+/* POSIX headers */
+#include <pthread.h>
+#include <unistd.h>
+
+/* Library header */
+#include <pthreadpool.h>
+
+#define PTHREADPOOL_CACHELINE_SIZE 64
+#define PTHREADPOOL_CACHELINE_ALIGNED __attribute__((__aligned__(PTHREADPOOL_CACHELINE_SIZE)))
+#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message)
+
+enum thread_state {
+ thread_state_idle,
+ thread_state_compute_1d,
+ thread_state_shutdown,
+};
+
+struct PTHREADPOOL_CACHELINE_ALIGNED thread_info {
+ /**
+ * Index of the first element in the work range.
+ * Before processing a new element the owning worker thread increments this value.
+ */
+ volatile size_t range_start;
+ /**
+ * Index of the element after the last element of the work range.
+ * Before processing a new element the stealing worker thread decrements this value.
+ */
+ volatile size_t range_end;
+ /**
+ * The number of elements in the work range.
+ * Due to race conditions range_length <= range_end - range_start.
+ * The owning worker thread must decrement this value before incrementing @a range_start.
+ * The stealing worker thread must decrement this value before decrementing @a range_end.
+ */
+ volatile size_t range_length;
+ /**
+ * The active state of the thread.
+ */
+ volatile enum thread_state state;
+ /**
+ * Thread number in the 0..threads_count-1 range.
+ */
+ size_t thread_number;
+ /**
+ * The pthread object corresponding to the thread.
+ */
+ pthread_t thread_object;
+ /**
+ * Condition variable used to wake up the thread.
+ * When the thread is idle, it waits on this condition variable.
+ */
+ pthread_cond_t wakeup_condvar;
+};
+
+PTHREADPOOL_STATIC_ASSERT(sizeof(struct thread_info) % PTHREADPOOL_CACHELINE_SIZE == 0, "thread_info structure must occupy an integer number of cache lines (64 bytes)");
+
+struct PTHREADPOOL_CACHELINE_ALIGNED pthreadpool {
+ /**
+ * The number of threads that signalled completion of an operation.
+ */
+ volatile size_t checkedin_threads;
+ /**
+ * The function to call for each item.
+ */
+ volatile pthreadpool_function_1d_t function;
+ /**
+ * The first argument to the item processing function.
+ */
+ void *volatile argument;
+ /**
+ * Serializes concurrent calls to @a pthreadpool_compute_* from different threads.
+ */
+ pthread_mutex_t execution_mutex;
+ /**
+ * Guards access to the @a checkedin_threads variable.
+ */
+ pthread_mutex_t barrier_mutex;
+ /**
+ * Condition variable to wait until all threads check in.
+ */
+ pthread_cond_t barrier_condvar;
+ /**
+ * Guards access to the @a state variables.
+ */
+ pthread_mutex_t state_mutex;
+ /**
+ * Condition variable to wait for change of @a state variable.
+ */
+ pthread_cond_t state_condvar;
+ /**
+ * The number of threads in the thread pool. Never changes after initialization.
+ */
+ size_t threads_count;
+ /**
+ * Thread information structures that immediately follow this structure.
+ */
+ struct thread_info threads[];
+};
+
+PTHREADPOOL_STATIC_ASSERT(sizeof(struct pthreadpool) % PTHREADPOOL_CACHELINE_SIZE == 0, "pthreadpool structure must occupy an integer number of cache lines (64 bytes)");
+
+static void checkin_worker_thread(struct pthreadpool* threadpool) {
+ pthread_mutex_lock(&threadpool->barrier_mutex);
+ const size_t checkedin_threads = threadpool->checkedin_threads + 1;
+ threadpool->checkedin_threads = checkedin_threads;
+ if (checkedin_threads == threadpool->threads_count) {
+ pthread_cond_signal(&threadpool->barrier_condvar);
+ }
+ pthread_mutex_unlock(&threadpool->barrier_mutex);
+}
+
+static void wait_worker_threads(struct pthreadpool* threadpool) {
+ if (threadpool->checkedin_threads != threadpool->threads_count) {
+ pthread_mutex_lock(&threadpool->barrier_mutex);
+ while (threadpool->checkedin_threads != threadpool->threads_count) {
+ pthread_cond_wait(&threadpool->barrier_condvar, &threadpool->barrier_mutex);
+ };
+ pthread_mutex_unlock(&threadpool->barrier_mutex);
+ }
+}
+
+static void wakeup_worker_threads(struct pthreadpool* threadpool) {
+ pthread_mutex_lock(&threadpool->state_mutex);
+ threadpool->checkedin_threads = 0; /* Locking of barrier_mutex not needed: readers are sleeping */
+ pthread_cond_broadcast(&threadpool->state_condvar);
+ pthread_mutex_unlock(&threadpool->state_mutex); /* Do wake up */
+}
+
+inline static bool atomic_decrement(volatile size_t* value) {
+ size_t actual_value = *value;
+ if (actual_value != 0) {
+ size_t expected_value;
+ do {
+ expected_value = actual_value;
+ const size_t new_value = actual_value - 1;
+ actual_value = __sync_val_compare_and_swap(value, expected_value, new_value);
+ } while ((actual_value != expected_value) && (actual_value != 0));
+ }
+ return actual_value != 0;
+}
+
+static void thread_compute_1d(struct pthreadpool* threadpool, struct thread_info* thread) {
+ const pthreadpool_function_1d_t function = threadpool->function;
+ void *const argument = threadpool->argument;
+ /* Process thread's own range of items */
+ size_t range_start = thread->range_start;
+ while (atomic_decrement(&thread->range_length)) {
+ function(argument, range_start++);
+ }
+ /* Done, now look for other threads' items to steal */
+ const size_t thread_number = thread->thread_number;
+ const size_t threads_count = threadpool->threads_count;
+ for (size_t tid = (thread_number + 1) % threads_count; tid != thread_number; tid = (tid + 1) % threads_count) {
+ struct thread_info* other_thread = &threadpool->threads[tid];
+ if (other_thread->state != thread_state_idle) {
+ while (atomic_decrement(&other_thread->range_length)) {
+ const size_t item_id = __sync_sub_and_fetch(&other_thread->range_end, 1);
+ function(argument, item_id);
+ }
+ }
+ }
+}
+
+static void* thread_main(void* arg) {
+ struct thread_info* thread = (struct thread_info*) arg;
+ struct pthreadpool* threadpool = ((struct pthreadpool*) (thread - thread->thread_number)) - 1;
+
+ /* Check in */
+ checkin_worker_thread(threadpool);
+
+ /* Monitor the state changes and act accordingly */
+ for (;;) {
+ /* Lock the state mutex */
+ pthread_mutex_lock(&threadpool->state_mutex);
+ /* Read the state */
+ enum thread_state state;
+ while ((state = thread->state) == thread_state_idle) {
+ /* Wait for state change */
+ pthread_cond_wait(&threadpool->state_condvar, &threadpool->state_mutex);
+ }
+ /* Read non-idle state */
+ pthread_mutex_unlock(&threadpool->state_mutex);
+ switch (state) {
+ case thread_state_compute_1d:
+ thread_compute_1d(threadpool, thread);
+ break;
+ case thread_state_shutdown:
+ return NULL;
+ case thread_state_idle:
+ /* To inhibit compiler warning */
+ break;
+ }
+ /* Notify the master thread that we finished processing */
+ thread->state = thread_state_idle;
+ checkin_worker_thread(threadpool);
+ };
+}
+
+struct pthreadpool* pthreadpool_create(size_t threads_count) {
+ if (threads_count == 0) {
+ threads_count = (size_t) sysconf(_SC_NPROCESSORS_ONLN);
+ }
+ struct pthreadpool* threadpool = memalign(64, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info));
+ memset(threadpool, 0, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info));
+ threadpool->threads_count = threads_count;
+ pthread_mutex_init(&threadpool->execution_mutex, NULL);
+ pthread_mutex_init(&threadpool->barrier_mutex, NULL);
+ pthread_cond_init(&threadpool->barrier_condvar, NULL);
+ pthread_mutex_init(&threadpool->state_mutex, NULL);
+ pthread_cond_init(&threadpool->state_condvar, NULL);
+
+ for (size_t tid = 0; tid < threads_count; tid++) {
+ threadpool->threads[tid].thread_number = tid;
+ pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]);
+ }
+
+ /* Wait until all threads initialize */
+ wait_worker_threads(threadpool);
+ return threadpool;
+}
+
+size_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) {
+ return threadpool->threads_count;
+}
+
+static inline size_t multiply_divide(size_t a, size_t b, size_t d) {
+ #if defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 4)
+		return (size_t) ((((uint64_t) a) * ((uint64_t) b)) / ((uint64_t) d));
+ #elif defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 8)
+		return (size_t) ((((__uint128_t) a) * ((__uint128_t) b)) / ((__uint128_t) d));
+ #else
+ #error "Unsupported platform"
+ #endif
+}
+
+void pthreadpool_compute_1d(
+ struct pthreadpool* threadpool,
+ pthreadpool_function_1d_t function,
+ void* argument,
+ size_t items)
+{
+ /* Protect the global threadpool structures */
+ pthread_mutex_lock(&threadpool->execution_mutex);
+
+ /* Spread the work between threads */
+ for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
+ struct thread_info* thread = &threadpool->threads[tid];
+ thread->range_start = multiply_divide(items, tid, threadpool->threads_count);
+ thread->range_end = multiply_divide(items, tid + 1, threadpool->threads_count);
+ thread->range_length = thread->range_end - thread->range_start;
+ thread->state = thread_state_compute_1d;
+ }
+
+ /* Setup global arguments */
+ threadpool->function = function;
+ threadpool->argument = argument;
+
+ /* Wake up the threads */
+ wakeup_worker_threads(threadpool);
+
+ /* Wait until the threads finish computation */
+ wait_worker_threads(threadpool);
+
+ /* Unprotect the global threadpool structures */
+ pthread_mutex_unlock(&threadpool->execution_mutex);
+}
+
+void pthreadpool_destroy(struct pthreadpool* threadpool) {
+ /* Update threads' states */
+ for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
+ threadpool->threads[tid].state = thread_state_shutdown;
+ }
+
+ /* Wake up the threads */
+ wakeup_worker_threads(threadpool);
+
+ /* Wait until all threads return */
+ for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
+ pthread_join(threadpool->threads[tid].thread_object, NULL);
+ }
+
+ /* Release resources */
+ pthread_mutex_destroy(&threadpool->execution_mutex);
+ pthread_mutex_destroy(&threadpool->barrier_mutex);
+ pthread_cond_destroy(&threadpool->barrier_condvar);
+ pthread_mutex_destroy(&threadpool->state_mutex);
+ pthread_cond_destroy(&threadpool->state_condvar);
+ free(threadpool);
+}
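
As a side note on the partitioning above: pthreadpool_compute_1d gives thread tid the half-open range [items*tid/threads_count, items*(tid+1)/threads_count), with multiply_divide performing the multiplication in wider arithmetic so the intermediate product cannot overflow size_t. A standalone sketch of that arithmetic (illustrative values, not part of this commit):

#include <stddef.h>
#include <stdio.h>

int main(void) {
	/* For items = 10 and threads_count = 4 this prints the ranges
	 * [0, 2), [2, 5), [5, 7) and [7, 10): every item is covered exactly
	 * once and no two ranges differ in length by more than one item. */
	const size_t items = 10, threads_count = 4;
	for (size_t tid = 0; tid < threads_count; tid++) {
		const size_t range_start = items * tid / threads_count;
		const size_t range_end = items * (tid + 1) / threads_count;
		printf("thread %zu: [%zu, %zu)\n", tid, range_start, range_end);
	}
	return 0;
}
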
diff --git a/test/pthreadpool.cc b/test/pthreadpool.cc
new file mode 100644
index 0000000..e1c2559
--- /dev/null
+++ b/test/pthreadpool.cc
@@ -0,0 +1,111 @@
+#include <gtest/gtest.h>
+
+#include <pthreadpool.h>
+
+const size_t itemsCount1D = 1024;
+
+TEST(SetupAndShutdown, Basic) {
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+ pthreadpool_destroy(threadpool);
+}
+
+static void computeNothing1D(void*, size_t) {
+}
+
+TEST(Compute1D, Basic) {
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+ pthreadpool_compute_1d(threadpool, computeNothing1D, NULL, itemsCount1D);
+ pthreadpool_destroy(threadpool);
+}
+
+static void checkRange1D(void*, size_t itemId) {
+ EXPECT_LT(itemId, itemsCount1D);
+}
+
+TEST(Compute1D, ValidRange) {
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+ pthreadpool_compute_1d(threadpool, checkRange1D, NULL, itemsCount1D);
+ pthreadpool_destroy(threadpool);
+}
+
+static void setTrue1D(bool indicators[], size_t itemId) {
+ indicators[itemId] = true;
+}
+
+TEST(Compute1D, AllItemsProcessed) {
+ bool processed[itemsCount1D];
+ memset(processed, 0, sizeof(processed));
+
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+ pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(setTrue1D), processed, itemsCount1D);
+ for (size_t itemId = 0; itemId < itemsCount1D; itemId++) {
+ EXPECT_TRUE(processed[itemId]) << "Item " << itemId << " not processed";
+ }
+ pthreadpool_destroy(threadpool);
+}
+
+static void increment1D(int counters[], size_t itemId) {
+ counters[itemId] += 1;
+}
+
+TEST(Compute1D, EachItemProcessedOnce) {
+ int processedCount[itemsCount1D];
+ memset(processedCount, 0, sizeof(processedCount));
+
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+ pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(increment1D), processedCount, itemsCount1D);
+ for (size_t itemId = 0; itemId < itemsCount1D; itemId++) {
+ EXPECT_EQ(1, processedCount[itemId]) << "Item " << itemId << " processed " << processedCount[itemId] << " times";
+ }
+ pthreadpool_destroy(threadpool);
+}
+
+TEST(Compute1D, EachItemProcessedMultipleTimes) {
+ int processedCount[itemsCount1D];
+ memset(processedCount, 0, sizeof(processedCount));
+ const size_t iterations = 100;
+
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+
+ for (size_t iteration = 0; iteration < iterations; iteration++) {
+ pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(increment1D), processedCount, itemsCount1D);
+ }
+ for (size_t itemId = 0; itemId < itemsCount1D; itemId++) {
+ EXPECT_EQ(iterations, processedCount[itemId]) << "Item " << itemId << " processed " << processedCount[itemId] << " times";
+ }
+ pthreadpool_destroy(threadpool);
+}
+
+static void workImbalance1D(volatile size_t* computedItems, size_t itemId) {
+ __sync_fetch_and_add(computedItems, 1);
+ if (itemId == 0) {
+ /* Wait until all items are computed */
+ while (*computedItems != itemsCount1D) {
+ __sync_synchronize();
+ }
+ }
+}
+
+TEST(Compute1D, WorkStealing) {
+ volatile size_t computedItems = 0;
+
+ pthreadpool* threadpool = pthreadpool_create(0);
+ EXPECT_TRUE(threadpool != nullptr);
+
+ pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(workImbalance1D), reinterpret_cast<void*>(const_cast<size_t*>(&computedItems)), itemsCount1D);
+ EXPECT_EQ(computedItems, itemsCount1D);
+
+ pthreadpool_destroy(threadpool);
+}
+
+int main(int argc, char* argv[]) {
+ setenv("TERM", "xterm-256color", 0);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
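
The tests above cast callbacks such as setTrue1D to pthreadpool_function_1d_t via reinterpret_cast. An equivalent cast-free pattern (illustrative, not part of this commit) keeps the declared void(void*, size_t) signature and recovers the typed pointer inside the callback:

/* Same behaviour as setTrue1D, but matching pthreadpool_function_1d_t
 * exactly, so no function-pointer cast is needed at the call site. */
static void setTrue1DUntyped(void* context, size_t itemId) {
	bool* indicators = (bool*) context;
	indicators[itemId] = true;
}

pthreadpool_compute_1d(threadpool, setTrue1DUntyped, processed, itemsCount1D);
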