aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcheckbuild.py2
-rw-r--r--ndk/test/devices.py6
-rw-r--r--ndk/test/test_workqueue.py2
-rw-r--r--ndk/workqueue.py51
-rwxr-xr-xrun_tests.py108
-rw-r--r--tests/testlib.py3
6 files changed, 127 insertions, 45 deletions
diff --git a/checkbuild.py b/checkbuild.py
index f15eac42c..0ccaf41a8 100755
--- a/checkbuild.py
+++ b/checkbuild.py
@@ -1277,7 +1277,7 @@ class SourceProperties(ndk.builds.Module):
])
-def launch_build(module, out_dir, dist_dir, args, log_dir):
+def launch_build(_worker_data, module, out_dir, dist_dir, args, log_dir):
log_path = os.path.join(log_dir, module.name) + '.log'
tee = subprocess.Popen(["tee", log_path], stdin=subprocess.PIPE)
try:
diff --git a/ndk/test/devices.py b/ndk/test/devices.py
index 37e74ab4b..5eb7e5ac5 100644
--- a/ndk/test/devices.py
+++ b/ndk/test/devices.py
@@ -322,6 +322,10 @@ class DeviceFleet(object):
return self.devices[version].keys()
+def create_device(_worker_data, serial, precache):
+ return Device(serial, precache)
+
+
def get_all_attached_devices(workqueue):
"""Returns a list of all connected devices."""
if distutils.spawn.find_executable('adb') is None:
@@ -352,7 +356,7 @@ def get_all_attached_devices(workqueue):
# Caching all the device details via getprop can actually take quite a
# bit of time. Do it in parallel to minimize the cost.
- workqueue.add_task(Device, serial, True)
+ workqueue.add_task(create_device, serial, True)
devices = []
while not workqueue.finished():
diff --git a/ndk/test/test_workqueue.py b/ndk/test/test_workqueue.py
index ba31afa80..8778aa8c2 100644
--- a/ndk/test/test_workqueue.py
+++ b/ndk/test/test_workqueue.py
@@ -24,7 +24,7 @@ import unittest
import ndk.workqueue
-def put(i):
+def put(_worker_data, i):
    """Returns the passed argument."""
return i
diff --git a/ndk/workqueue.py b/ndk/workqueue.py
index 6735f5711..d163926f4 100644
--- a/ndk/workqueue.py
+++ b/ndk/workqueue.py
@@ -35,7 +35,7 @@ def worker_sigterm_handler(_signum, _frame):
sys.exit()
-def _flush_queue(queue):
+def flush_queue(queue):
"""Flushes all pending items from a Queue."""
try:
while True:
@@ -58,7 +58,7 @@ class TaskError(Exception):
super(TaskError, self).__init__(trace)
-def worker_main(task_queue, result_queue):
+def worker_main(worker_data, task_queue, result_queue):
"""Main loop for worker processes.
Args:
@@ -72,7 +72,7 @@ def worker_main(task_queue, result_queue):
logger().debug('worker %d waiting for work', os.getpid())
task = task_queue.get()
logger().debug('worker %d running task', os.getpid())
- result = task.run()
+ result = task.run(worker_data)
logger().debug('worker %d putting result', os.getpid())
result_queue.put(result)
except SystemExit:
@@ -104,9 +104,9 @@ class Task(object):
self.args = args
self.kwargs = kwargs
- def run(self):
+ def run(self, worker_data):
"""Invokes the task."""
- return self.func(*self.args, **self.kwargs)
+ return self.func(worker_data, *self.args, **self.kwargs)
class ProcessPoolWorkQueue(object):
@@ -114,7 +114,8 @@ class ProcessPoolWorkQueue(object):
join_timeout = 8 # Timeout for join before trying SIGKILL.
- def __init__(self, num_workers=multiprocessing.cpu_count()):
+ def __init__(self, num_workers=multiprocessing.cpu_count(),
+ task_queue=None, result_queue=None, worker_data=None):
"""Creates a WorkQueue.
Worker threads are spawned immediately and remain live until both
@@ -122,6 +123,14 @@ class ProcessPoolWorkQueue(object):
Args:
num_workers: Number of worker processes to spawn.
+ task_queue: multiprocessing.Queue for tasks. Allows multiple work
+ queues to share a single task queue. If None, the work queue
+ creates its own.
+ result_queue: multiprocessing.Queue for results. Allows multiple
+ work queues to share a single result queue. If None, the work
+ queue creates its own.
+ worker_data: Data to be passed to every task run by this work
+ queue.
"""
if sys.platform == 'win32':
# TODO(danalbert): Port ProcessPoolWorkQueue to Windows.
@@ -129,8 +138,20 @@ class ProcessPoolWorkQueue(object):
# groups, which are not supported on Windows.
raise NotImplementedError
- self.task_queue = multiprocessing.Queue()
- self.result_queue = multiprocessing.Queue()
+ self.task_queue = task_queue
+ self.owns_task_queue = False
+ if task_queue is None:
+ self.task_queue = multiprocessing.Queue()
+ self.owns_task_queue = True
+
+ self.result_queue = result_queue
+ self.owns_result_queue = False
+ if result_queue is None:
+ self.result_queue = multiprocessing.Queue()
+ self.owns_result_queue = True
+
+ self.worker_data = worker_data
+
self.workers = []
# multiprocessing.JoinableQueue's join isn't able to implement
# finished() because it doesn't come in a non-blocking flavor.
@@ -176,8 +197,10 @@ class ProcessPoolWorkQueue(object):
We call _flush after all workers have been terminated to ensure that we
can exit cleanly.
"""
- _flush_queue(self.task_queue)
- _flush_queue(self.result_queue)
+ if self.owns_task_queue:
+ flush_queue(self.task_queue)
+ if self.owns_result_queue:
+ flush_queue(self.result_queue)
def join(self):
"""Waits for all worker processes to exit."""
@@ -203,7 +226,8 @@ class ProcessPoolWorkQueue(object):
"""
for _ in range(num_workers):
worker = multiprocessing.Process(
- target=worker_main, args=(self.task_queue, self.result_queue))
+ target=worker_main,
+ args=(self.worker_data, self.task_queue, self.result_queue))
worker.start()
self.workers.append(worker)
@@ -214,9 +238,10 @@ class DummyWorkQueue(object):
Useful for debugging when trying to determine if an issue is being caused
by multiprocess specific behavior.
"""
- def __init__(self):
+ def __init__(self, worker_data=None):
"""Creates a SerialWorkQueue."""
self.task_queue = collections.deque()
+ self.worker_data = worker_data
def add_task(self, func, *args, **kwargs):
"""Queues up a new task for execution.
@@ -234,7 +259,7 @@ class DummyWorkQueue(object):
"""Executes a task and returns the result."""
task = self.task_queue.popleft()
try:
- return task.run()
+ return task.run(self.worker_data)
except:
trace = ''.join(traceback.format_exception(*sys.exc_info()))
raise TaskError(trace)
diff --git a/run_tests.py b/run_tests.py
index 48b67a70c..cbe59c48e 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse
import json
import logging
+import multiprocessing
import os
import posixpath
import random
@@ -183,10 +184,10 @@ class LibcxxTestCase(TestCase):
class TestRun(object):
- """A test case mapped to the device it will run on."""
- def __init__(self, test_case, device):
+ """A test case mapped to the device group it will run on."""
+ def __init__(self, test_case, device_group):
self.test_case = test_case
- self.device = device
+ self.device_group = device_group
@property
def name(self):
@@ -200,17 +201,17 @@ class TestRun(object):
def config(self):
return self.test_case.config
- def make_result(self, adb_result_tuple):
+ def make_result(self, adb_result_tuple, device):
status, out, _ = adb_result_tuple
if status == 0:
result = ndk.test.result.Success(self)
else:
- out = '\n'.join([str(self.device), out])
+ out = '\n'.join([str(device), out])
result = ndk.test.result.Failure(self, out)
- return self.fixup_xfail(result)
+ return self.fixup_xfail(result, device)
- def fixup_xfail(self, result):
- config, bug = self.test_case.check_broken(self.device)
+ def fixup_xfail(self, result, device):
+ config, bug = self.test_case.check_broken(device)
if config is not None:
if result.failed():
return ndk.test.result.ExpectedFailure(self, config, bug)
@@ -219,12 +220,12 @@ class TestRun(object):
raise ValueError('Test result must have either failed or passed.')
return result
- def run(self):
- config = self.test_case.check_unsupported(self.device)
+ def run(self, device):
+ config = self.test_case.check_unsupported(device)
if config is not None:
message = 'test unsupported for {}'.format(config)
return ndk.test.result.Skipped(self, message)
- return self.make_result(self.test_case.run(self.device))
+ return self.make_result(self.test_case.run(device), device)
def build_tests(ndk_dir, out_dir, clean, printer, config, test_filter):
@@ -320,7 +321,7 @@ def enumerate_tests(test_dir, test_filter, config_filter):
return tests
-def clear_test_directory(device):
+def clear_test_directory(_worker_data, device):
print('Clearing test directory on {}.'.format(device))
cmd = ['rm', '-r', DEVICE_TEST_BASE_DIR]
logger().info('%s: shell_nocheck "%s"', device.name, cmd)
@@ -345,7 +346,8 @@ def adb_has_feature(feature):
return feature in features
-def push_tests_to_device(src_dir, dest_dir, config, device, use_sync):
+def push_tests_to_device(_worker_data, src_dir, dest_dir, config, device,
+ use_sync):
print('Pushing {} tests to {}.'.format(config, device))
logger().info('%s: mkdir %s', device.name, dest_dir)
device.shell_nocheck(['mkdir', dest_dir])
@@ -414,7 +416,7 @@ def asan_device_setup(ndk_path, device):
device, out))
-def setup_asan_for_device(ndk_path, device):
+def setup_asan_for_device(_worker_data, ndk_path, device):
print('Performing ASAN setup for {}'.format(device))
disable_verity_and_wait_for_reboot(device)
asan_device_setup(ndk_path, device)
@@ -439,8 +441,9 @@ def perform_asan_setup(workqueue, ndk_path, groups_for_config):
workqueue.get_result()
-def run_test(test):
- return test.run()
+def run_test(worker_data, test):
+ device = worker_data[0]
+ return test.run(device)
def print_test_stats(test_groups):
@@ -484,7 +487,7 @@ def match_configs_to_device_groups(fleet, configs):
return groups_for_config
-def create_test_runs(test_groups, groups_for_config):
+def pair_test_runs(test_groups, groups_for_config):
"""Creates a TestRun object for each device/test case pairing."""
test_runs = []
for config, test_cases in test_groups.items():
@@ -492,9 +495,7 @@ def create_test_runs(test_groups, groups_for_config):
continue
for group in groups_for_config[config]:
- for shard_idx, device in enumerate(group.devices):
- sharded_tests = test_cases[shard_idx::len(group.devices)]
- test_runs.extend([TestRun(tc, device) for tc in sharded_tests])
+ test_runs.extend([TestRun(tc, group) for tc in test_cases])
return test_runs
@@ -643,6 +644,49 @@ class ConfigFilter(object):
return config_tuple in self.config_tuples
+class ShardingWorkQueue(object):
+ def __init__(self, device_groups, procs_per_device):
+ self.result_queue = multiprocessing.Queue()
+ self.task_queues = {}
+ self.work_queues = []
+ self.num_tasks = 0
+ for group in device_groups:
+ self.task_queues[group] = multiprocessing.Queue()
+ for device in group.devices:
+ self.work_queues.append(
+ ndk.workqueue.WorkQueue(
+ procs_per_device, task_queue=self.task_queues[group],
+ result_queue=self.result_queue, worker_data=[device]))
+
+ def add_task(self, group, func, *args, **kwargs):
+ self.task_queues[group].put(
+ ndk.workqueue.Task(func, args, kwargs))
+ self.num_tasks += 1
+
+ def get_result(self):
+ """Gets a result from the queue, blocking until one is available."""
+ result = self.result_queue.get()
+ if type(result) == ndk.workqueue.TaskError:
+ raise result
+ self.num_tasks -= 1
+ return result
+
+ def terminate(self):
+ for work_queue in self.work_queues:
+ work_queue.terminate()
+ for task_queue in self.task_queues.values():
+ ndk.workqueue.flush_queue(task_queue)
+ ndk.workqueue.flush_queue(self.result_queue)
+
+ def join(self):
+ for work_queue in self.work_queues:
+ work_queue.join()
+
+ def finished(self):
+ """Returns True if all tasks have completed execution."""
+ return self.num_tasks == 0
+
+
def main():
total_timer = ndk.timer.Timer()
total_timer.start()
@@ -745,25 +789,33 @@ def main():
asan_setup_timer = ndk.timer.Timer()
with asan_setup_timer:
perform_asan_setup(workqueue, args.ndk, groups_for_config)
+ finally:
+ workqueue.terminate()
+ workqueue.join()
+
+ shard_queue = ShardingWorkQueue(fleet.get_unique_device_groups(), 8)
+ try:
+ # Need an input queue per device group, a single result queue, and a
+ # pool of threads per device.
# Shuffle the test runs to distribute the load more evenly. These are
# ordered by (build config, device, test), so most of the tests running
# at any given point in time are all running on the same device.
- test_runs = create_test_runs(test_groups, groups_for_config)
+ test_runs = pair_test_runs(test_groups, groups_for_config)
random.shuffle(test_runs)
test_run_timer = ndk.timer.Timer()
with test_run_timer:
- for test in test_runs:
- workqueue.add_task(run_test, test)
+ for test_run in test_runs:
+ shard_queue.add_task(test_run.device_group, run_test, test_run)
- wait_for_results(report, workqueue, printer)
- restart_flaky_tests(report, workqueue)
- wait_for_results(report, workqueue, printer)
+ wait_for_results(report, shard_queue, printer)
+ restart_flaky_tests(report, shard_queue)
+ wait_for_results(report, shard_queue, printer)
printer.print_summary(report)
finally:
- workqueue.terminate()
- workqueue.join()
+ shard_queue.terminate()
+ shard_queue.join()
total_timer.finish()
diff --git a/tests/testlib.py b/tests/testlib.py
index 0bf4646d3..1457141ee 100644
--- a/tests/testlib.py
+++ b/tests/testlib.py
@@ -256,10 +256,11 @@ def _fixup_negative_test(result):
return result
-def _run_test(suite, test, obj_dir, dist_dir, test_filters):
+def _run_test(_worker_data, suite, test, obj_dir, dist_dir, test_filters):
"""Runs a given test according to the given filters.
Args:
+ _worker_data: Data identifying the worker process.
suite: Name of the test suite the test belongs to.
test: The test to be run.
obj_dir: Out directory for intermediate build artifacts.