Diffstat (limited to 'pw_unit_test/py/pw_unit_test')
-rw-r--r--  pw_unit_test/py/pw_unit_test/rpc.py                  87
-rw-r--r--  pw_unit_test/py/pw_unit_test/serial_test_runner.py  196
-rw-r--r--  pw_unit_test/py/pw_unit_test/test_runner.py          54
3 files changed, 293 insertions, 44 deletions
diff --git a/pw_unit_test/py/pw_unit_test/rpc.py b/pw_unit_test/py/pw_unit_test/rpc.py
index ac1c56a2f..459071ec0 100644
--- a/pw_unit_test/py/pw_unit_test/rpc.py
+++ b/pw_unit_test/py/pw_unit_test/rpc.py
@@ -17,7 +17,7 @@ import enum
import abc
from dataclasses import dataclass
import logging
-from typing import Iterable
+from typing import Iterable, List, Tuple
from pw_rpc.client import Services
from pw_rpc.callback_client import OptionalTimeout, UseDefault
@@ -133,13 +133,28 @@ class LoggingEventHandler(EventHandler):
log(' Actual: %s', expectation.evaluated_expression)
+@dataclass(frozen=True)
+class TestRecord:
+ """Class for recording test results."""
+
+ passing_tests: Tuple[TestCase, ...]
+ failing_tests: Tuple[TestCase, ...]
+ disabled_tests: Tuple[TestCase, ...]
+
+ def all_tests_passed(self) -> bool:
+ return not self.failing_tests
+
+ def __bool__(self) -> bool:
+ return self.all_tests_passed()
+
+
def run_tests(
rpcs: Services,
report_passed_expectations: bool = False,
test_suites: Iterable[str] = (),
event_handlers: Iterable[EventHandler] = (LoggingEventHandler(),),
timeout_s: OptionalTimeout = UseDefault.VALUE,
-) -> bool:
+) -> TestRecord:
"""Runs unit tests on a device over Pigweed RPC.
Calls each of the provided event handlers as test events occur, and returns
@@ -174,39 +189,53 @@ def run_tests(
for event_handler in event_handlers:
event_handler.run_all_tests_start()
- all_tests_passed = False
+ passing_tests: List[TestCase] = []
+ failing_tests: List[TestCase] = []
+ disabled_tests: List[TestCase] = []
for response in test_responses:
- if response.HasField('test_case_start'):
- raw_test_case = response.test_case_start
- current_test_case = _test_case(raw_test_case)
-
- for event_handler in event_handlers:
- if response.HasField('test_run_start'):
+ if response.HasField('test_run_start'):
+ for event_handler in event_handlers:
event_handler.run_all_tests_start()
- elif response.HasField('test_run_end'):
+ elif response.HasField('test_run_end'):
+ for event_handler in event_handlers:
event_handler.run_all_tests_end(
response.test_run_end.passed, response.test_run_end.failed
)
- if response.test_run_end.failed == 0:
- all_tests_passed = True
- elif response.HasField('test_case_start'):
+ assert len(passing_tests) == response.test_run_end.passed
+ assert len(failing_tests) == response.test_run_end.failed
+ test_record = TestRecord(
+ passing_tests=tuple(passing_tests),
+ failing_tests=tuple(failing_tests),
+ disabled_tests=tuple(disabled_tests),
+ )
+ elif response.HasField('test_case_start'):
+ raw_test_case = response.test_case_start
+ current_test_case = _test_case(raw_test_case)
+ for event_handler in event_handlers:
event_handler.test_case_start(current_test_case)
- elif response.HasField('test_case_end'):
- result = TestCaseResult(response.test_case_end)
+ elif response.HasField('test_case_end'):
+ result = TestCaseResult(response.test_case_end)
+ for event_handler in event_handlers:
event_handler.test_case_end(current_test_case, result)
- elif response.HasField('test_case_disabled'):
- event_handler.test_case_disabled(
- _test_case(response.test_case_disabled)
- )
- elif response.HasField('test_case_expectation'):
- raw_expectation = response.test_case_expectation
- expectation = TestExpectation(
- raw_expectation.expression,
- raw_expectation.evaluated_expression,
- raw_expectation.line_number,
- raw_expectation.success,
- )
+ if result == TestCaseResult.SUCCESS:
+ passing_tests.append(current_test_case)
+ else:
+ failing_tests.append(current_test_case)
+ elif response.HasField('test_case_disabled'):
+ raw_test_case = response.test_case_disabled
+ current_test_case = _test_case(raw_test_case)
+ for event_handler in event_handlers:
+ event_handler.test_case_disabled(current_test_case)
+ disabled_tests.append(current_test_case)
+ elif response.HasField('test_case_expectation'):
+ raw_expectation = response.test_case_expectation
+ expectation = TestExpectation(
+ raw_expectation.expression,
+ raw_expectation.evaluated_expression,
+ raw_expectation.line_number,
+ raw_expectation.success,
+ )
+ for event_handler in event_handlers:
event_handler.test_case_expect(current_test_case, expectation)
-
- return all_tests_passed
+ return test_record
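
The rpc.py change above replaces the boolean return value of run_tests() with a TestRecord. Below is a minimal sketch of how a caller might consume the new return value; the `rpcs` argument stands in for an established pw_rpc Services object, and 'MySuite' is a hypothetical suite name, neither of which comes from this patch.

from pw_unit_test.rpc import run_tests

def run_and_report(rpcs) -> bool:
    # run_tests() now returns a TestRecord instead of a bare bool.
    record = run_tests(rpcs, test_suites=['MySuite'])
    # TestRecord.__bool__ delegates to all_tests_passed(), so callers that
    # previously checked the boolean result keep working unchanged.
    if not record:
        for test in record.failing_tests:
            print('FAILED:', test)
    return bool(record)
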
diff --git a/pw_unit_test/py/pw_unit_test/serial_test_runner.py b/pw_unit_test/py/pw_unit_test/serial_test_runner.py
new file mode 100644
index 000000000..cdb879ba7
--- /dev/null
+++ b/pw_unit_test/py/pw_unit_test/serial_test_runner.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# Copyright 2023 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""This library facilitates automating unit tests on devices with serial ports.
+
+This library assumes that the on-device test runner emits test results as
+plain text over a serial port, and that a test run is triggered by writing a
+pre-defined input (DEFAULT_TEST_START_CHARACTER) to the same serial port the
+results are emitted from.
+"""
+
+import abc
+import logging
+from pathlib import Path
+
+import serial # type: ignore
+
+
+_LOG = logging.getLogger("serial_test_runner")
+
+# Verification of test pass/failure depends on these strings. If the formatting
+# or output of the simple_printing_event_handler changes, this may need to be
+# updated.
+_TESTS_STARTING_STRING = b'[==========] Running all tests.'
+_TESTS_DONE_STRING = b'[==========] Done running all tests.'
+_TEST_FAILURE_STRING = b'[ FAILED ]'
+
+# Character used to trigger test start.
+DEFAULT_TEST_START_CHARACTER = ' '.encode('utf-8')
+
+
+class FlashingFailure(Exception):
+ """A simple exception to be raised when flashing fails."""
+
+
+class TestingFailure(Exception):
+ """A simple exception to be raised when a testing step fails."""
+
+
+class DeviceNotFound(Exception):
+ """A simple exception to be raised when unable to connect to a device."""
+
+
+class SerialTestingDevice(abc.ABC):
+ """A device that supports automated testing via parsing serial output."""
+
+ @abc.abstractmethod
+ def load_binary(self, binary: Path) -> None:
+ """Flashes the specified binary to this device.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ FlashingFailure: The binary could not be flashed.
+ """
+
+ @abc.abstractmethod
+ def serial_port(self) -> str:
+ """Returns the name of the com port this device is enumerated on.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ """
+
+ @abc.abstractmethod
+ def baud_rate(self) -> int:
+ """Returns the baud rate to use when connecting to this device.
+
+ Raises:
+ DeviceNotFound: This device is no longer available.
+ """
+
+
+def _log_subprocess_output(level, output: bytes, logger: logging.Logger):
+ """Logs subprocess output line-by-line."""
+
+ lines = output.decode('utf-8', errors='replace').splitlines()
+ for line in lines:
+ logger.log(level, line)
+
+
+def trigger_test_run(
+ port: str,
+ baud_rate: int,
+ test_timeout: float,
+ trigger_data: bytes = DEFAULT_TEST_START_CHARACTER,
+) -> bytes:
+ """Triggers a test run, and returns captured test results."""
+
+ serial_data = bytearray()
+ device = serial.Serial(baudrate=baud_rate, port=port, timeout=test_timeout)
+ if not device.is_open:
+ raise TestingFailure('Failed to open device')
+
+ # Flush input buffer and trigger the test start.
+ device.reset_input_buffer()
+ device.write(trigger_data)
+
+ # Block and wait for the first byte.
+ serial_data += device.read()
+ if not serial_data:
+ raise TestingFailure('Device not producing output')
+
+ # Read with a reasonable timeout until we stop getting characters.
+ while True:
+ bytes_read = device.readline()
+ if not bytes_read:
+ break
+ serial_data += bytes_read
+ if serial_data.rfind(_TESTS_DONE_STRING) != -1:
+ # Switch to a much more aggressive timeout, since the last one or
+ # two lines should print out immediately (one line if all tests
+ # fail or all pass, two lines if the results are mixed).
+ device.timeout = 0.01
+
+ # Remove carriage returns.
+ serial_data = serial_data.replace(b"\r", b"")
+
+ # Try to trim captured results to only contain most recent test run.
+ test_start_index = serial_data.rfind(_TESTS_STARTING_STRING)
+ return (
+ serial_data
+ if test_start_index == -1
+ else serial_data[test_start_index:]
+ )
+
+
+def handle_test_results(
+ test_output: bytes, logger: logging.Logger = _LOG
+) -> None:
+ """Parses test output to determine whether tests passed or failed.
+
+ Raises:
+ TestingFailure if any tests fail or if test results are incomplete.
+ """
+
+ if test_output.find(_TESTS_STARTING_STRING) == -1:
+ raise TestingFailure('Failed to find test start')
+
+ if test_output.rfind(_TESTS_DONE_STRING) == -1:
+ _log_subprocess_output(logging.INFO, test_output, logger)
+ raise TestingFailure('Tests did not complete')
+
+ if test_output.rfind(_TEST_FAILURE_STRING) != -1:
+ _log_subprocess_output(logging.INFO, test_output, logger)
+ raise TestingFailure('Test suite had one or more failures')
+
+ _log_subprocess_output(logging.DEBUG, test_output, logger)
+
+ logger.info('Test passed!')
+
+
+def run_device_test(
+ device: SerialTestingDevice,
+ binary: Path,
+ test_timeout: float,
+ logger: logging.Logger = _LOG,
+) -> bool:
+ """Runs tests on a device.
+
+ When a unit test run fails, results will be logged as an error.
+
+ Args:
+ device: The device to run tests on.
+ binary: The binary containing tests that will be flashed to the device.
+ test_timeout: If the device stops producing output for longer than
+ this timeout, the test is considered stuck and the run is aborted.
+
+ Returns:
+ True if all tests passed.
+ """
+
+ logger.info('Flashing binary to device')
+ device.load_binary(binary)
+ try:
+ logger.info('Running test')
+ test_output = trigger_test_run(
+ device.serial_port(), device.baud_rate(), test_timeout
+ )
+ if test_output:
+ handle_test_results(test_output, logger)
+ except TestingFailure as err:
+ logger.error(err)
+ return False
+
+ return True
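
To use the new serial_test_runner module, a target supplies a SerialTestingDevice implementation and passes it to run_device_test(). The sketch below is illustrative only: ExampleDevice, the 'flash-tool' command, the serial port, and the binary path are hypothetical placeholders and are not part of this patch.

from pathlib import Path
import subprocess

from pw_unit_test import serial_test_runner

class ExampleDevice(serial_test_runner.SerialTestingDevice):
    """Hypothetical board flashed by an external 'flash-tool' command."""

    def __init__(self, port: str, baud: int) -> None:
        self._port = port
        self._baud = baud

    def load_binary(self, binary: Path) -> None:
        # 'flash-tool' is a stand-in for the target's real flashing utility.
        result = subprocess.run(
            ['flash-tool', '--port', self._port, str(binary)]
        )
        if result.returncode != 0:
            raise serial_test_runner.FlashingFailure(f'Could not flash {binary}')

    def serial_port(self) -> str:
        return self._port

    def baud_rate(self) -> int:
        return self._baud

# run_device_test() flashes the binary, writes DEFAULT_TEST_START_CHARACTER to
# the port, captures the plain-text results, and returns True only if every
# test passed.
device = ExampleDevice('/dev/ttyUSB0', 115200)
passed = serial_test_runner.run_device_test(
    device, Path('out/device_tests.elf'), test_timeout=5.0
)
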
diff --git a/pw_unit_test/py/pw_unit_test/test_runner.py b/pw_unit_test/py/pw_unit_test/test_runner.py
index a7b06ac5e..1869ce026 100644
--- a/pw_unit_test/py/pw_unit_test/test_runner.py
+++ b/pw_unit_test/py/pw_unit_test/test_runner.py
@@ -65,12 +65,13 @@ def register_arguments(parser: argparse.ArgumentParser) -> None:
'-m', '--timeout', type=float, help='Timeout for test runner in seconds'
)
parser.add_argument(
- '--coverage-profraw',
- type=str,
- help='The name of the coverage profraw file to produce with the'
- ' coverage information from this test. Only provide this if the test'
- ' should be run for coverage and is properly instrumented.',
+ '-e',
+ '--env',
+ nargs='*',
+ help='Environment variables to set for the test. These should be of the'
+ ' form `var_name=value`.',
)
+
parser.add_argument(
'runner_args', nargs="*", help='Arguments to forward to the test runner'
)
@@ -160,15 +161,17 @@ class TestRunner:
executable: str,
args: Sequence[str],
tests: Iterable[Test],
- coverage_profraw: Optional[str] = None,
+ env: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
+ verbose: bool = False,
) -> None:
self._executable: str = executable
self._args: Sequence[str] = args
self._tests: List[Test] = list(tests)
- self._coverage_profraw = coverage_profraw
+ self._env: Dict[str, str] = env or {}
self._timeout = timeout
self._result_sink: Optional[Dict[str, str]] = None
+ self.verbose = verbose
# Access go/result-sink, if available.
ctx_path = Path(os.environ.get("LUCI_CONTEXT", ''))
@@ -201,11 +204,11 @@ class TestRunner:
test.start_time = datetime.datetime.now(datetime.timezone.utc)
start_time = time.monotonic()
try:
- env = {}
- if self._coverage_profraw is not None:
- env['LLVM_PROFILE_FILE'] = str(Path(self._coverage_profraw))
process = await pw_cli.process.run_async(
- *command, env=env, timeout=self._timeout
+ *command,
+ env=self._env,
+ timeout=self._timeout,
+ log_output=self.verbose,
)
except subprocess.CalledProcessError as err:
_LOG.error(err)
@@ -287,7 +290,7 @@ class TestRunner:
# Need to decode the bytes back to ASCII or they will not be
# encodable by json.dumps.
#
- # TODO(b/248349219): Instead of stripping the ANSI color
+ # TODO: b/248349219 - Instead of stripping the ANSI color
# codes, convert them to HTML.
"contents": base64.b64encode(
_strip_ansi(process.output)
@@ -462,14 +465,34 @@ def tests_from_paths(paths: Sequence[str]) -> List[Test]:
return tests
+def parse_env(env: Sequence[str]) -> Dict[str, str]:
+ """Returns a dictionary of environment names and values.
+
+ Args:
+ env: List of strings of the form "key=val".
+
+ Raises:
+ ValueError if `env` is malformed.
+ """
+ envvars = {}
+ if env:
+ for envvar in env:
+ parts = envvar.split('=')
+ if len(parts) != 2:
+ raise ValueError(f'malformed environment variable: {envvar}')
+ envvars[parts[0]] = parts[1]
+ return envvars
+
+
async def find_and_run_tests(
root: str,
runner: str,
runner_args: Sequence[str] = (),
- coverage_profraw: Optional[str] = None,
+ env: Sequence[str] = (),
timeout: Optional[float] = None,
group: Optional[Sequence[str]] = None,
test: Optional[Sequence[str]] = None,
+ verbose: bool = False,
) -> int:
"""Runs some unit tests."""
@@ -478,8 +501,10 @@ async def find_and_run_tests(
else:
tests = tests_from_groups(group, root)
+ envvars = parse_env(env)
+
test_runner = TestRunner(
- runner, runner_args, tests, coverage_profraw, timeout
+ runner, runner_args, tests, envvars, timeout, verbose
)
await test_runner.run_tests()
@@ -499,7 +524,6 @@ def main() -> int:
)
args_as_dict = dict(vars(parser.parse_args()))
- del args_as_dict['verbose']
return asyncio.run(find_and_run_tests(**args_as_dict))
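
The test_runner.py change replaces the --coverage-profraw flag with a generic -e/--env flag; each argument is parsed by the new parse_env() helper into a dictionary that is passed to each test subprocess. A small sketch of the expected behavior follows; using LLVM_PROFILE_FILE to reproduce the old coverage workflow is only an example, not something this patch mandates.

from pw_unit_test.test_runner import parse_env

# Each `--env` argument must look like NAME=VALUE; anything else raises
# ValueError.
env = parse_env(['LLVM_PROFILE_FILE=out/test.profraw', 'MY_FLAG=1'])
assert env == {'LLVM_PROFILE_FILE': 'out/test.profraw', 'MY_FLAG': '1'}
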