diff options
-rwxr-xr-x  checkbuild.py              |   2
-rw-r--r--  ndk/test/devices.py        |   6
-rw-r--r--  ndk/test/test_workqueue.py |   2
-rw-r--r--  ndk/workqueue.py           |  51
-rwxr-xr-x  run_tests.py               | 108
-rw-r--r--  tests/testlib.py           |   3
6 files changed, 127 insertions, 45 deletions
diff --git a/checkbuild.py b/checkbuild.py index f15eac42c..0ccaf41a8 100755 --- a/checkbuild.py +++ b/checkbuild.py @@ -1277,7 +1277,7 @@ class SourceProperties(ndk.builds.Module): ]) -def launch_build(module, out_dir, dist_dir, args, log_dir): +def launch_build(_worker_data, module, out_dir, dist_dir, args, log_dir): log_path = os.path.join(log_dir, module.name) + '.log' tee = subprocess.Popen(["tee", log_path], stdin=subprocess.PIPE) try: diff --git a/ndk/test/devices.py b/ndk/test/devices.py index 37e74ab4b..5eb7e5ac5 100644 --- a/ndk/test/devices.py +++ b/ndk/test/devices.py @@ -322,6 +322,10 @@ class DeviceFleet(object): return self.devices[version].keys() +def create_device(_worker_data, serial, precache): + return Device(serial, precache) + + def get_all_attached_devices(workqueue): """Returns a list of all connected devices.""" if distutils.spawn.find_executable('adb') is None: @@ -352,7 +356,7 @@ def get_all_attached_devices(workqueue): # Caching all the device details via getprop can actually take quite a # bit of time. Do it in parallel to minimize the cost. 
- workqueue.add_task(Device, serial, True) + workqueue.add_task(create_device, serial, True) devices = [] while not workqueue.finished(): diff --git a/ndk/test/test_workqueue.py b/ndk/test/test_workqueue.py index ba31afa80..8778aa8c2 100644 --- a/ndk/test/test_workqueue.py +++ b/ndk/test/test_workqueue.py @@ -24,7 +24,7 @@ import unittest import ndk.workqueue -def put(i): +def put(_worker_data, i): """Returns an the passed argument.""" return i diff --git a/ndk/workqueue.py b/ndk/workqueue.py index 6735f5711..d163926f4 100644 --- a/ndk/workqueue.py +++ b/ndk/workqueue.py @@ -35,7 +35,7 @@ def worker_sigterm_handler(_signum, _frame): sys.exit() -def _flush_queue(queue): +def flush_queue(queue): """Flushes all pending items from a Queue.""" try: while True: @@ -58,7 +58,7 @@ class TaskError(Exception): super(TaskError, self).__init__(trace) -def worker_main(task_queue, result_queue): +def worker_main(worker_data, task_queue, result_queue): """Main loop for worker processes. Args: @@ -72,7 +72,7 @@ def worker_main(task_queue, result_queue): logger().debug('worker %d waiting for work', os.getpid()) task = task_queue.get() logger().debug('worker %d running task', os.getpid()) - result = task.run() + result = task.run(worker_data) logger().debug('worker %d putting result', os.getpid()) result_queue.put(result) except SystemExit: @@ -104,9 +104,9 @@ class Task(object): self.args = args self.kwargs = kwargs - def run(self): + def run(self, worker_data): """Invokes the task.""" - return self.func(*self.args, **self.kwargs) + return self.func(worker_data, *self.args, **self.kwargs) class ProcessPoolWorkQueue(object): @@ -114,7 +114,8 @@ class ProcessPoolWorkQueue(object): join_timeout = 8 # Timeout for join before trying SIGKILL. - def __init__(self, num_workers=multiprocessing.cpu_count()): + def __init__(self, num_workers=multiprocessing.cpu_count(), + task_queue=None, result_queue=None, worker_data=None): """Creates a WorkQueue. 
Worker threads are spawned immediately and remain live until both @@ -122,6 +123,14 @@ class ProcessPoolWorkQueue(object): Args: num_workers: Number of worker processes to spawn. + task_queue: multiprocessing.Queue for tasks. Allows multiple work + queues to share a single task queue. If None, the work queue + creates its own. + result_queue: multiprocessing.Queue for results. Allows multiple + work queues to share a single result queue. If None, the work + queue creates its own. + worker_data: Data to be passed to every task run by this work + queue. """ if sys.platform == 'win32': # TODO(danalbert): Port ProcessPoolWorkQueue to Windows. @@ -129,8 +138,20 @@ class ProcessPoolWorkQueue(object): # groups, which are not supported on Windows. raise NotImplementedError - self.task_queue = multiprocessing.Queue() - self.result_queue = multiprocessing.Queue() + self.task_queue = task_queue + self.owns_task_queue = False + if task_queue is None: + self.task_queue = multiprocessing.Queue() + self.owns_task_queue = True + + self.result_queue = result_queue + self.owns_result_queue = False + if result_queue is None: + self.result_queue = multiprocessing.Queue() + self.owns_result_queue = True + + self.worker_data = worker_data + self.workers = [] # multiprocessing.JoinableQueue's join isn't able to implement # finished() because it doesn't come in a non-blocking flavor. @@ -176,8 +197,10 @@ class ProcessPoolWorkQueue(object): We call _flush after all workers have been terminated to ensure that we can exit cleanly. 
""" - _flush_queue(self.task_queue) - _flush_queue(self.result_queue) + if self.owns_task_queue: + flush_queue(self.task_queue) + if self.owns_result_queue: + flush_queue(self.result_queue) def join(self): """Waits for all worker processes to exit.""" @@ -203,7 +226,8 @@ class ProcessPoolWorkQueue(object): """ for _ in range(num_workers): worker = multiprocessing.Process( - target=worker_main, args=(self.task_queue, self.result_queue)) + target=worker_main, + args=(self.worker_data, self.task_queue, self.result_queue)) worker.start() self.workers.append(worker) @@ -214,9 +238,10 @@ class DummyWorkQueue(object): Useful for debugging when trying to determine if an issue is being caused by multiprocess specific behavior. """ - def __init__(self): + def __init__(self, worker_data=None): """Creates a SerialWorkQueue.""" self.task_queue = collections.deque() + self.worker_data = worker_data def add_task(self, func, *args, **kwargs): """Queues up a new task for execution. @@ -234,7 +259,7 @@ class DummyWorkQueue(object): """Executes a task and returns the result.""" task = self.task_queue.popleft() try: - return task.run() + return task.run(self.worker_data) except: trace = ''.join(traceback.format_exception(*sys.exc_info())) raise TaskError(trace) diff --git a/run_tests.py b/run_tests.py index 48b67a70c..cbe59c48e 100755 --- a/run_tests.py +++ b/run_tests.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import json import logging +import multiprocessing import os import posixpath import random @@ -183,10 +184,10 @@ class LibcxxTestCase(TestCase): class TestRun(object): - """A test case mapped to the device it will run on.""" - def __init__(self, test_case, device): + """A test case mapped to the device group it will run on.""" + def __init__(self, test_case, device_group): self.test_case = test_case - self.device = device + self.device_group = device_group @property def name(self): @@ -200,17 +201,17 @@ class TestRun(object): def config(self): 
return self.test_case.config - def make_result(self, adb_result_tuple): + def make_result(self, adb_result_tuple, device): status, out, _ = adb_result_tuple if status == 0: result = ndk.test.result.Success(self) else: - out = '\n'.join([str(self.device), out]) + out = '\n'.join([str(device), out]) result = ndk.test.result.Failure(self, out) - return self.fixup_xfail(result) + return self.fixup_xfail(result, device) - def fixup_xfail(self, result): - config, bug = self.test_case.check_broken(self.device) + def fixup_xfail(self, result, device): + config, bug = self.test_case.check_broken(device) if config is not None: if result.failed(): return ndk.test.result.ExpectedFailure(self, config, bug) @@ -219,12 +220,12 @@ class TestRun(object): raise ValueError('Test result must have either failed or passed.') return result - def run(self): - config = self.test_case.check_unsupported(self.device) + def run(self, device): + config = self.test_case.check_unsupported(device) if config is not None: message = 'test unsupported for {}'.format(config) return ndk.test.result.Skipped(self, message) - return self.make_result(self.test_case.run(self.device)) + return self.make_result(self.test_case.run(device), device) def build_tests(ndk_dir, out_dir, clean, printer, config, test_filter): @@ -320,7 +321,7 @@ def enumerate_tests(test_dir, test_filter, config_filter): return tests -def clear_test_directory(device): +def clear_test_directory(_worker_data, device): print('Clearing test directory on {}.'.format(device)) cmd = ['rm', '-r', DEVICE_TEST_BASE_DIR] logger().info('%s: shell_nocheck "%s"', device.name, cmd) @@ -345,7 +346,8 @@ def adb_has_feature(feature): return feature in features -def push_tests_to_device(src_dir, dest_dir, config, device, use_sync): +def push_tests_to_device(_worker_data, src_dir, dest_dir, config, device, + use_sync): print('Pushing {} tests to {}.'.format(config, device)) logger().info('%s: mkdir %s', device.name, dest_dir) device.shell_nocheck(['mkdir', 
dest_dir]) @@ -414,7 +416,7 @@ def asan_device_setup(ndk_path, device): device, out)) -def setup_asan_for_device(ndk_path, device): +def setup_asan_for_device(_worker_data, ndk_path, device): print('Performing ASAN setup for {}'.format(device)) disable_verity_and_wait_for_reboot(device) asan_device_setup(ndk_path, device) @@ -439,8 +441,9 @@ def perform_asan_setup(workqueue, ndk_path, groups_for_config): workqueue.get_result() -def run_test(test): - return test.run() +def run_test(worker_data, test): + device = worker_data[0] + return test.run(device) def print_test_stats(test_groups): @@ -484,7 +487,7 @@ def match_configs_to_device_groups(fleet, configs): return groups_for_config -def create_test_runs(test_groups, groups_for_config): +def pair_test_runs(test_groups, groups_for_config): """Creates a TestRun object for each device/test case pairing.""" test_runs = [] for config, test_cases in test_groups.items(): @@ -492,9 +495,7 @@ def create_test_runs(test_groups, groups_for_config): continue for group in groups_for_config[config]: - for shard_idx, device in enumerate(group.devices): - sharded_tests = test_cases[shard_idx::len(group.devices)] - test_runs.extend([TestRun(tc, device) for tc in sharded_tests]) + test_runs.extend([TestRun(tc, group) for tc in test_cases]) return test_runs @@ -643,6 +644,49 @@ class ConfigFilter(object): return config_tuple in self.config_tuples +class ShardingWorkQueue(object): + def __init__(self, device_groups, procs_per_device): + self.result_queue = multiprocessing.Queue() + self.task_queues = {} + self.work_queues = [] + self.num_tasks = 0 + for group in device_groups: + self.task_queues[group] = multiprocessing.Queue() + for device in group.devices: + self.work_queues.append( + ndk.workqueue.WorkQueue( + procs_per_device, task_queue=self.task_queues[group], + result_queue=self.result_queue, worker_data=[device])) + + def add_task(self, group, func, *args, **kwargs): + self.task_queues[group].put( + ndk.workqueue.Task(func, args, 
kwargs)) + self.num_tasks += 1 + + def get_result(self): + """Gets a result from the queue, blocking until one is available.""" + result = self.result_queue.get() + if type(result) == ndk.workqueue.TaskError: + raise result + self.num_tasks -= 1 + return result + + def terminate(self): + for work_queue in self.work_queues: + work_queue.terminate() + for task_queue in self.task_queues.values(): + ndk.workqueue.flush_queue(task_queue) + ndk.workqueue.flush_queue(self.result_queue) + + def join(self): + for work_queue in self.work_queues: + work_queue.join() + + def finished(self): + """Returns True if all tasks have completed execution.""" + return self.num_tasks == 0 + + def main(): total_timer = ndk.timer.Timer() total_timer.start() @@ -745,25 +789,33 @@ def main(): asan_setup_timer = ndk.timer.Timer() with asan_setup_timer: perform_asan_setup(workqueue, args.ndk, groups_for_config) + finally: + workqueue.terminate() + workqueue.join() + + shard_queue = ShardingWorkQueue(fleet.get_unique_device_groups(), 8) + try: + # Need an input queue per device group, a single result queue, and a + # pool of threads per device. # Shuffle the test runs to distribute the load more evenly. These are # ordered by (build config, device, test), so most of the tests running # at any given point in time are all running on the same device. 
- test_runs = create_test_runs(test_groups, groups_for_config) + test_runs = pair_test_runs(test_groups, groups_for_config) random.shuffle(test_runs) test_run_timer = ndk.timer.Timer() with test_run_timer: - for test in test_runs: - workqueue.add_task(run_test, test) + for test_run in test_runs: + shard_queue.add_task(test_run.device_group, run_test, test_run) - wait_for_results(report, workqueue, printer) - restart_flaky_tests(report, workqueue) - wait_for_results(report, workqueue, printer) + wait_for_results(report, shard_queue, printer) + restart_flaky_tests(report, shard_queue) + wait_for_results(report, shard_queue, printer) printer.print_summary(report) finally: - workqueue.terminate() - workqueue.join() + shard_queue.terminate() + shard_queue.join() total_timer.finish() diff --git a/tests/testlib.py b/tests/testlib.py index 0bf4646d3..1457141ee 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -256,10 +256,11 @@ def _fixup_negative_test(result): return result -def _run_test(suite, test, obj_dir, dist_dir, test_filters): +def _run_test(_worker_data, suite, test, obj_dir, dist_dir, test_filters): """Runs a given test according to the given filters. Args: + _worker_data: Data identifying the worker process. suite: Name of the test suite the test belongs to. test: The test to be run. obj_dir: Out directory for intermediate build artifacts. |