diff options
author | Marat Dukhan <maratek@gmail.com> | 2015-08-22 17:46:29 -0400 |
---|---|---|
committer | Marat Dukhan <maratek@gmail.com> | 2015-08-22 17:46:29 -0400 |
commit | 0a312191e15b30e8e50a89077edc92bc2ee5682d (patch) | |
tree | 673d7b763a6c95ac4d29499b1d2f18da9778a441 | |
parent | faf6ff30a6ac3fc9be846cf3a89942ebec78aeeb (diff) | |
download | pthreadpool-0a312191e15b30e8e50a89077edc92bc2ee5682d.tar.gz |
Initial thread pool implementation
-rw-r--r-- | .gitignore | 17 | ||||
-rwxr-xr-x | configure.py | 204 | ||||
-rw-r--r-- | include/pthreadpool.h | 73 | ||||
-rw-r--r-- | src/pthreadpool.c | 296 | ||||
-rw-r--r-- | test/pthreadpool.cc | 111 |
5 files changed, 701 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f42c5bc --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# Ninja files +.ninja_deps +.ninja_log +build.ninja + +# Build objects and artifacts +build/* +artifacts/* + +# System files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db diff --git a/configure.py b/configure.py new file mode 100755 index 0000000..72db0c4 --- /dev/null +++ b/configure.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python + +import os +import sys +import glob +import argparse +import ninja_syntax + + +class Configuration: + + def __init__(self, options, ninja_build_file=os.path.join(os.path.dirname(os.path.abspath(__file__)), "build.ninja")): + self.output = open(ninja_build_file, "w") + self.writer = ninja_syntax.Writer(self.output) + self.source_dir = None + self.build_dir = None + self.artifact_dir = None + self.prefix_dir = options.prefix + self.include_dirs = [] + self.object_ext = ".bc" + + # Variables + self.writer.variable("nacl_sdk_dir", options.nacl_sdk) + self._set_pnacl_vars() + self.writer.variable("cflags", "-std=gnu11") + self.writer.variable("cxxflags", "-std=gnu++11") + self.writer.variable("optflags", "-O3") + + # Rules + self.writer.rule("cc", "$pnacl_cc -o $out -c $in -MMD -MF $out.d $optflags $cflags $includes", + deps="gcc", depfile="$out.d", + description="CC[PNaCl] $descpath") + self.writer.rule("cxx", "$pnacl_cxx -o $out -c $in -MMD -MF $out.d $optflags $cxxflags $includes", + deps="gcc", depfile="$out.d", + description="CXX[PNaCl] $descpath") + self.writer.rule("ccld", "$pnacl_cc -o $out $in $libs $libdirs $ldflags", + description="CCLD[PNaCl] $descpath") + self.writer.rule("cxxld", "$pnacl_cxx -o $out $in $libs $libdirs $ldflags", + description="CXXLD[PNaCl] $descpath") + self.writer.rule("ar", "$pnacl_ar rcs $out $in", + description="AR[PNaCl] $descpath") + self.writer.rule("finalize", "$pnacl_finalize $finflags -o $out $in", + description="FINALIZE[PNaCl] $descpath") + self.writer.rule("translate", "$pnacl_translate -arch $arch -o $out $in", + description="TRANSLATE[PNaCl] $descpath") + self.writer.rule("run", "$pnacl_sel_ldr $in", + description="RUN[PNaCl] $descpath", pool="console") + self.writer.rule("install", "install -m $mode $in $out", + description="INSTALL $out") + + + def _set_pnacl_vars(self): + if sys.platform == "win32": + self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/win_pnacl") + self.writer.variable("pnacl_cc", "$pnacl_toolchain_dir/bin/pnacl-clang.bat") + self.writer.variable("pnacl_cxx", "$pnacl_toolchain_dir/bin/pnacl-clang++.bat") + self.writer.variable("pnacl_ar", "$pnacl_toolchain_dir/bin/pnacl-ar.bat") + self.writer.variable("pnacl_finalize", "$pnacl_toolchain_dir/bin/pnacl-finalize.bat") + self.writer.variable("pnacl_translate", "$pnacl_toolchain_dir/bin/pnacl-translate.bat") + elif sys.platform == "linux2" or sys.platform == "darwin": + if sys.platform == "linux2": + self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/linux_pnacl") + else: + self.writer.variable("pnacl_toolchain_dir", "$nacl_sdk_dir/toolchain/mac_pnacl") + self.writer.variable("pnacl_cc", "$pnacl_toolchain_dir/bin/pnacl-clang") + self.writer.variable("pnacl_cxx", "$pnacl_toolchain_dir/bin/pnacl-clang++") + self.writer.variable("pnacl_ar", "$pnacl_toolchain_dir/bin/pnacl-ar") + self.writer.variable("pnacl_finalize", "$pnacl_toolchain_dir/bin/pnacl-finalize") + self.writer.variable("pnacl_translate", "$pnacl_toolchain_dir/bin/pnacl-translate") + else: + raise OSError("Unsupported platform: " + sys.platform) + self.writer.variable("pnacl_sel_ldr", "$nacl_sdk_dir/tools/sel_ldr.py") + + + def _compile(self, rule, source_file, object_file): + if not os.path.isabs(source_file): + source_file = os.path.join(self.source_dir, source_file) + if object_file is None: + object_file = os.path.join(self.build_dir, os.path.relpath(source_file, self.source_dir)) + self.object_ext + variables = { + "descpath": os.path.relpath(source_file, self.source_dir) + } + if self.include_dirs: + variables["includes"] = " ".join(["-I" + i for i in self.include_dirs]) + self.writer.build(object_file, rule, source_file, variables=variables) + return object_file + + + def cc(self, source_file, object_file=None): + return self._compile("cc", source_file, object_file) + + + def cxx(self, source_file, object_file=None): + return self._compile("cxx", source_file, object_file) + + + def _link(self, rule, object_files, binary_file, binary_dir, lib_dirs, libs): + if not os.path.isabs(binary_file): + binary_file = os.path.join(binary_dir, binary_file) + variables = { + "descpath": os.path.relpath(binary_file, binary_dir) + } + if lib_dirs: + variables["libdirs"] = " ".join(["-L" + l for l in lib_dirs]) + if libs: + variables["libs"] = " ".join(["-l" + l for l in libs]) + self.writer.build(binary_file, rule, object_files, variables=variables) + return binary_file + + + def ccld(self, object_files, binary_file, lib_dirs=[], libs=[]): + return self._link("ccld", object_files, binary_file, self.build_dir, lib_dirs, libs) + + + def cxxld(self, object_files, binary_file, lib_dirs=[], libs=[]): + return self._link("cxxld", object_files, binary_file, self.build_dir, lib_dirs, libs) + + + def ar(self, object_files, archive_file): + if not os.path.isabs(archive_file): + archive_file = os.path.join(self.artifact_dir, archive_file) + variables = { + "descpath": os.path.relpath(archive_file, self.artifact_dir) + } + self.writer.build(archive_file, "ar", object_files, variables=variables) + return archive_file + + def finalize(self, binary_file, executable_file): + if not os.path.isabs(binary_file): + binary_file = os.path.join(self.build_dir, binary_file) + if not os.path.isabs(executable_file): + executable_file = os.path.join(self.artifact_dir, executable_file) + variables = { + "descpath": os.path.relpath(executable_file, self.artifact_dir) + } + self.writer.build(executable_file, "finalize", binary_file, variables=variables) + return executable_file + + def translate(self, portable_file, native_file): + if not os.path.isabs(portable_file): + portable_file = os.path.join(self.artifact_dir, portable_file) + if not os.path.isabs(native_file): + native_file = os.path.join(self.artifact_dir, native_file) + variables = { + "descpath": os.path.relpath(portable_file, self.artifact_dir), + "arch": "x86_64" + } + self.writer.build(native_file, "translate", portable_file, variables=variables) + return native_file + + def run(self, executable_file, target): + variables = { + "descpath": os.path.relpath(executable_file, self.artifact_dir) + } + self.writer.build(target, "run", executable_file, variables=variables) + + def install(self, source_file, destination_file, mode=0o644): + if not os.path.isabs(destination_file): + destination_file = os.path.join(self.prefix_dir, destination_file) + variables = { + "mode": "0%03o" % mode + } + self.writer.build(destination_file, "install", source_file, variables=variables) + return destination_file + + +parser = argparse.ArgumentParser(description="PThreadPool configuration script") +parser.add_argument("--with-nacl-sdk", dest="nacl_sdk", default=os.getenv("NACL_SDK_ROOT"), + help="Native Client (Pepper) SDK to use") +parser.add_argument("--prefix", dest="prefix", default="/usr/local") + + +def main(): + options = parser.parse_args() + + config = Configuration(options) + + root_dir = os.path.dirname(os.path.abspath(__file__)) + + config.source_dir = os.path.join(root_dir, "src") + config.build_dir = os.path.join(root_dir, "build") + config.artifact_dir = os.path.join(root_dir, "artifacts") + config.include_dirs = [os.path.join("$nacl_sdk_dir", "include"), os.path.join(root_dir, "include"), os.path.join(root_dir, "src")] + + pthreadpool_object = config.cc("pthreadpool.c") + pthreadpool_library = config.ar([pthreadpool_object], "libpthreadpool.a") + + config.source_dir = os.path.join(root_dir, "test") + config.build_dir = os.path.join(root_dir, "build", "test") + pthreadpool_test_object = config.cxx("pthreadpool.cc") + pthreadpool_test_binary = config.cxxld([pthreadpool_object, pthreadpool_test_object], "pthreadpool.bc", libs=["gtest"], lib_dirs=[os.path.join("$nacl_sdk_dir", "lib", "pnacl", "Release")]) + pthreadpool_test_binary = config.finalize(pthreadpool_test_binary, "pthreadpool.pexe") + pthreadpool_test_binary = config.translate(pthreadpool_test_binary, "pthreadpool.nexe") + config.run(pthreadpool_test_binary, "test") + + config.writer.default([pthreadpool_library, pthreadpool_test_binary]) + + config.writer.build("install", "phony", [ + config.install(os.path.join(root_dir, "include", "pthreadpool.h"), "include/pthreadpool.h"), + config.install(os.path.join(pthreadpool_object), "lib/libpthreadpool.a")]) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/include/pthreadpool.h b/include/pthreadpool.h new file mode 100644 index 0000000..dc0847d --- /dev/null +++ b/include/pthreadpool.h @@ -0,0 +1,73 @@ +#include <stddef.h> + +#ifndef PTHREADPOOL_H +#define PTHREADPOOL_H + +typedef struct pthreadpool* pthreadpool_t; + +typedef void (*pthreadpool_function_1d_t)(void*, size_t); +typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*pthreadpool_function_3d_t)(void*, size_t, size_t, size_t); + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Creates a thread pool with the specified number of threads. + * + * @param[in] threads_count The number of threads in the thread pool. + * A value of 0 has special interpretation: it creates a thread for each + * processor core available in the system. + * + * @returns A pointer to an opaque thread pool object. + * On error the function returns NULL and sets errno accordingly. + */ +pthreadpool_t pthreadpool_create(size_t threads_count); + +/** + * Queries the number of threads in a thread pool. + * + * @param[in] threadpool The thread pool to query. + * + * @returns The number of threads in the thread pool. + */ +uint32_t pthreadpool_get_threads_count(pthreadpool_t threadpool); + + +/** + * Processes items in parallel using threads from a thread pool. + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param[in] threadpool The thread pool to use for parallelisation. + * @param[in] function The function to call for each item. + * @param[in] argument The first argument passed to the @a function. + * @param[in] items The number of items to process. The @a function + * will be called once for each item. + */ +void pthreadpool_compute_1d( + pthreadpool_t threadpool, + pthreadpool_function_1d_t function, + void* argument, + size_t items); + +/** + * Terminates threads in the thread pool and releases associated resources. + * + * @warning Accessing the thread pool after a call to this function constitutes + * undefined behaviour and may cause data corruption. + * + * @param[in,out] threadpool The thread pool to destroy. + */ +void pthreadpool_destroy(pthreadpool_t threadpool); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* PTHREADPOOL_H */ diff --git a/src/pthreadpool.c b/src/pthreadpool.c new file mode 100644 index 0000000..f270204 --- /dev/null +++ b/src/pthreadpool.c @@ -0,0 +1,296 @@ +/* Standard C headers */ +#include <stdint.h> +#include <stdbool.h> +#include <malloc.h> +#include <string.h> +#include <assert.h> + +/* POSIX headers */ +#include <pthread.h> +#include <unistd.h> + +/* Library header */ +#include <pthreadpool.h> + +#define PTHREADPOOL_CACHELINE_SIZE 64 +#define PTHREADPOOL_CACHELINE_ALIGNED __attribute__((__aligned__(PTHREADPOOL_CACHELINE_SIZE))) +#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message) + +enum thread_state { + thread_state_idle, + thread_state_compute_1d, + thread_state_shutdown, +}; + +struct PTHREADPOOL_CACHELINE_ALIGNED thread_info { + /** + * Index of the first element in the work range. + * Before processing a new element the owning worker thread increments this value. + */ + volatile size_t range_start; + /** + * Index of the element after the last element of the work range. + * Before processing a new element the stealing worker thread decrements this value. + */ + volatile size_t range_end; + /** + * The number of elements in the work range. + * Due to race conditions range_length <= range_end - range_start. + * The owning worker thread must decrement this value before incrementing @a range_start. + * The stealing worker thread must decrement this value before decrementing @a range_end. + */ + volatile size_t range_length; + /** + * The active state of the thread. + */ + volatile enum thread_state state; + /** + * Thread number in the 0..threads_count-1 range. + */ + size_t thread_number; + /** + * The pthread object corresponding to the thread. + */ + pthread_t thread_object; + /** + * Condition variable used to wake up the thread. + * When the thread is idle, it waits on this condition variable. + */ + pthread_cond_t wakeup_condvar; +}; + +PTHREADPOOL_STATIC_ASSERT(sizeof(struct thread_info) % PTHREADPOOL_CACHELINE_SIZE == 0, "thread_info structure must occupy an integer number of cache lines (64 bytes)"); + +struct PTHREADPOOL_CACHELINE_ALIGNED pthreadpool { + /** + * The number of threads that signalled completion of an operation. + */ + volatile size_t checkedin_threads; + /** + * The function to call for each item. + */ + volatile pthreadpool_function_1d_t function; + /** + * The first argument to the item processing function. + */ + void *volatile argument; + /** + * Serializes concurrent calls to @a pthreadpool_compute_* from different threads. + */ + pthread_mutex_t execution_mutex; + /** + * Guards access to the @a checkedin_threads variable. + */ + pthread_mutex_t barrier_mutex; + /** + * Condition variable to wait until all threads check in. + */ + pthread_cond_t barrier_condvar; + /** + * Guards access to the @a state variables. + */ + pthread_mutex_t state_mutex; + /** + * Condition variable to wait for change of @a state variable. + */ + pthread_cond_t state_condvar; + /** + * The number of threads in the thread pool. Never changes after initialization. + */ + size_t threads_count; + /** + * Thread information structures that immediately follow this structure. + */ + struct thread_info threads[]; +}; + +PTHREADPOOL_STATIC_ASSERT(sizeof(struct pthreadpool) % PTHREADPOOL_CACHELINE_SIZE == 0, "pthreadpool structure must occupy an integer number of cache lines (64 bytes)"); + +static void checkin_worker_thread(struct pthreadpool* threadpool) { + pthread_mutex_lock(&threadpool->barrier_mutex); + const size_t checkedin_threads = threadpool->checkedin_threads + 1; + threadpool->checkedin_threads = checkedin_threads; + if (checkedin_threads == threadpool->threads_count) { + pthread_cond_signal(&threadpool->barrier_condvar); + } + pthread_mutex_unlock(&threadpool->barrier_mutex); +} + +static void wait_worker_threads(struct pthreadpool* threadpool) { + if (threadpool->checkedin_threads != threadpool->threads_count) { + pthread_mutex_lock(&threadpool->barrier_mutex); + while (threadpool->checkedin_threads != threadpool->threads_count) { + pthread_cond_wait(&threadpool->barrier_condvar, &threadpool->barrier_mutex); + }; + pthread_mutex_unlock(&threadpool->barrier_mutex); + } +} + +static void wakeup_worker_threads(struct pthreadpool* threadpool) { + pthread_mutex_lock(&threadpool->state_mutex); + threadpool->checkedin_threads = 0; /* Locking of barrier_mutex not needed: readers are sleeping */ + pthread_cond_broadcast(&threadpool->state_condvar); + pthread_mutex_unlock(&threadpool->state_mutex); /* Do wake up */ +} + +inline static bool atomic_decrement(volatile size_t* value) { + size_t actual_value = *value; + if (actual_value != 0) { + size_t expected_value; + do { + expected_value = actual_value; + const size_t new_value = actual_value - 1; + actual_value = __sync_val_compare_and_swap(value, expected_value, new_value); + } while ((actual_value != expected_value) && (actual_value != 0)); + } + return actual_value != 0; +} + +static void thread_compute_1d(struct pthreadpool* threadpool, struct thread_info* thread) { + const pthreadpool_function_1d_t function = threadpool->function; + void *const argument = threadpool->argument; + /* Process thread's own range of items */ + size_t range_start = thread->range_start; + while (atomic_decrement(&thread->range_length)) { + function(argument, range_start++); + } + /* Done, now look for other threads' items to steal */ + const size_t thread_number = thread->thread_number; + const size_t threads_count = threadpool->threads_count; + for (size_t tid = (thread_number + 1) % threads_count; tid != thread_number; tid = (tid + 1) % threads_count) { + struct thread_info* other_thread = &threadpool->threads[tid]; + if (other_thread->state != thread_state_idle) { + while (atomic_decrement(&other_thread->range_length)) { + const size_t item_id = __sync_sub_and_fetch(&other_thread->range_end, 1); + function(argument, item_id); + } + } + } +} + +static void* thread_main(void* arg) { + struct thread_info* thread = (struct thread_info*) arg; + struct pthreadpool* threadpool = ((struct pthreadpool*) (thread - thread->thread_number)) - 1; + + /* Check in */ + checkin_worker_thread(threadpool); + + /* Monitor the state changes and act accordingly */ + for (;;) { + /* Lock the state mutex */ + pthread_mutex_lock(&threadpool->state_mutex); + /* Read the state */ + enum thread_state state; + while ((state = thread->state) == thread_state_idle) { + /* Wait for state change */ + pthread_cond_wait(&threadpool->state_condvar, &threadpool->state_mutex); + } + /* Read non-idle state */ + pthread_mutex_unlock(&threadpool->state_mutex); + switch (state) { + case thread_state_compute_1d: + thread_compute_1d(threadpool, thread); + break; + case thread_state_shutdown: + return NULL; + case thread_state_idle: + /* To inhibit compiler warning */ + break; + } + /* Notify the master thread that we finished processing */ + thread->state = thread_state_idle; + checkin_worker_thread(threadpool); + }; +} + +struct pthreadpool* pthreadpool_create(size_t threads_count) { + if (threads_count == 0) { + threads_count = (size_t) sysconf(_SC_NPROCESSORS_ONLN); + } + struct pthreadpool* threadpool = memalign(64, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info)); + memset(threadpool, 0, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info)); + threadpool->threads_count = threads_count; + pthread_mutex_init(&threadpool->execution_mutex, NULL); + pthread_mutex_init(&threadpool->barrier_mutex, NULL); + pthread_cond_init(&threadpool->barrier_condvar, NULL); + pthread_mutex_init(&threadpool->state_mutex, NULL); + pthread_cond_init(&threadpool->state_condvar, NULL); + + for (size_t tid = 0; tid < threads_count; tid++) { + threadpool->threads[tid].thread_number = tid; + pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]); + } + + /* Wait until all threads initialize */ + wait_worker_threads(threadpool); + return threadpool; +} + +uint32_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) { + return threadpool->threads_count; +} + +static inline size_t multiply_divide(size_t a, size_t b, size_t d) { + #if defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 4) + return (size_t) (((uint64_t) a) * ((uint64_t) b)) / ((uint64_t) d); + #elif defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 8) + return (size_t) (((uint128_t) a) * ((uint128_t) b)) / ((uint128_t) d); + #else + #error "Unsupported platform" + #endif +} + +void pthreadpool_compute_1d( + struct pthreadpool* threadpool, + pthreadpool_function_1d_t function, + void* argument, + size_t items) +{ + /* Protect the global threadpool structures */ + pthread_mutex_lock(&threadpool->execution_mutex); + + /* Spread the work between threads */ + for (size_t tid = 0; tid < threadpool->threads_count; tid++) { + struct thread_info* thread = &threadpool->threads[tid]; + thread->range_start = multiply_divide(items, tid, threadpool->threads_count); + thread->range_end = multiply_divide(items, tid + 1, threadpool->threads_count); + thread->range_length = thread->range_end - thread->range_start; + thread->state = thread_state_compute_1d; + } + + /* Setup global arguments */ + threadpool->function = function; + threadpool->argument = argument; + + /* Wake up the threads */ + wakeup_worker_threads(threadpool); + + /* Wait until the threads finish computation */ + wait_worker_threads(threadpool); + + /* Unprotect the global threadpool structures */ + pthread_mutex_unlock(&threadpool->execution_mutex); +} + +void pthreadpool_destroy(struct pthreadpool* threadpool) { + /* Update threads' states */ + for (size_t tid = 0; tid < threadpool->threads_count; tid++) { + threadpool->threads[tid].state = thread_state_shutdown; + } + + /* Wake up the threads */ + wakeup_worker_threads(threadpool); + + /* Wait until all threads return */ + for (size_t tid = 0; tid < threadpool->threads_count; tid++) { + pthread_join(threadpool->threads[tid].thread_object, NULL); + } + + /* Release resources */ + pthread_mutex_destroy(&threadpool->execution_mutex); + pthread_mutex_destroy(&threadpool->barrier_mutex); + pthread_cond_destroy(&threadpool->barrier_condvar); + pthread_mutex_destroy(&threadpool->state_mutex); + pthread_cond_destroy(&threadpool->state_condvar); + free(threadpool); +} diff --git a/test/pthreadpool.cc b/test/pthreadpool.cc new file mode 100644 index 0000000..e1c2559 --- /dev/null +++ b/test/pthreadpool.cc @@ -0,0 +1,111 @@ +#include <gtest/gtest.h> + +#include <pthreadpool.h> + +const size_t itemsCount1D = 1024; + +TEST(SetupAndShutdown, Basic) { + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + pthreadpool_destroy(threadpool); +} + +static void computeNothing1D(void*, size_t) { +} + +TEST(Compute1D, Basic) { + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + pthreadpool_compute_1d(threadpool, computeNothing1D, NULL, itemsCount1D); + pthreadpool_destroy(threadpool); +} + +static void checkRange1D(void*, size_t itemId) { + EXPECT_LT(itemId, itemsCount1D); +} + +TEST(Compute1D, ValidRange) { + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + pthreadpool_compute_1d(threadpool, checkRange1D, NULL, itemsCount1D); + pthreadpool_destroy(threadpool); +} + +static void setTrue1D(bool indicators[], size_t itemId) { + indicators[itemId] = true; +} + +TEST(Compute1D, AllItemsProcessed) { + bool processed[itemsCount1D]; + memset(processed, 0, sizeof(processed)); + + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(setTrue1D), processed, itemsCount1D); + for (size_t itemId = 0; itemId < itemsCount1D; itemId++) { + EXPECT_TRUE(processed[itemId]) << "Item " << itemId << " not processed"; + } + pthreadpool_destroy(threadpool); +} + +static void increment1D(int counters[], size_t itemId) { + counters[itemId] += 1; +} + +TEST(Compute1D, EachItemProcessedOnce) { + int processedCount[itemsCount1D]; + memset(processedCount, 0, sizeof(processedCount)); + + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(increment1D), processedCount, itemsCount1D); + for (size_t itemId = 0; itemId < itemsCount1D; itemId++) { + EXPECT_EQ(1, processedCount[itemId]) << "Item " << itemId << " processed " << processedCount[itemId] << " times"; + } + pthreadpool_destroy(threadpool); +} + +TEST(Compute1D, EachItemProcessedMultipleTimes) { + int processedCount[itemsCount1D]; + memset(processedCount, 0, sizeof(processedCount)); + const size_t iterations = 100; + + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + + for (size_t iteration = 0; iteration < iterations; iteration++) { + pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(increment1D), processedCount, itemsCount1D); + } + for (size_t itemId = 0; itemId < itemsCount1D; itemId++) { + EXPECT_EQ(iterations, processedCount[itemId]) << "Item " << itemId << " processed " << processedCount[itemId] << " times"; + } + pthreadpool_destroy(threadpool); +} + +static void workImbalance1D(volatile size_t* computedItems, size_t itemId) { + __sync_fetch_and_add(computedItems, 1); + if (itemId == 0) { + /* Wait until all items are computed */ + while (*computedItems != itemsCount1D) { + __sync_synchronize(); + } + } +} + +TEST(Compute1D, WorkStealing) { + volatile size_t computedItems = 0; + + pthreadpool* threadpool = pthreadpool_create(0); + EXPECT_TRUE(threadpool != nullptr); + + pthreadpool_compute_1d(threadpool, reinterpret_cast<pthreadpool_function_1d_t>(workImbalance1D), reinterpret_cast<void*>(const_cast<size_t*>(&computedItems)), itemsCount1D); + EXPECT_EQ(computedItems, itemsCount1D); + + pthreadpool_destroy(threadpool); +} + +int main(int argc, char* argv[]) { + setenv("TERM", "xterm-256color", 0); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |