aboutsummaryrefslogtreecommitdiff
path: root/pw_unit_test/py/pw_unit_test/serial_test_runner.py
blob: cdb879ba73a1958c98f2f990408697a1abef454b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#!/usr/bin/env python3
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This library facilitates automating unit tests on devices with serial ports.

This library assumes that the on-device test runner will emit the test results
as plain-text over a serial port, and tests will be triggered by a pre-defined
input (DEFAULT_TEST_START_CHARACTER) over the same serial port that results
are emitted from.
"""

import abc
import logging
from pathlib import Path

import serial  # type: ignore


_LOG = logging.getLogger("serial_test_runner")

# Verification of test pass/failure depends on these strings. If the formatting
# or output of the simple_printing_event_handler changes, this may need to be
# updated.
_TESTS_STARTING_STRING = b'[==========] Running all tests.'
_TESTS_DONE_STRING = b'[==========] Done running all tests.'
_TEST_FAILURE_STRING = b'[  FAILED  ]'

# Character used to trigger test start.
DEFAULT_TEST_START_CHARACTER = ' '.encode('utf-8')


class FlashingFailure(Exception):
    """A simple exception to be raised when flashing fails."""


class TestingFailure(Exception):
    """A simple exception to be raised when a testing step fails."""


class DeviceNotFound(Exception):
    """A simple exception to be raised when unable to connect to a device."""


class SerialTestingDevice(abc.ABC):
    """A device that supports automated testing via parsing serial output."""

    @abc.abstractmethod
    def load_binary(self, binary: Path) -> None:
        """Flashes the specified binary to this device.

        Raises:
          DeviceNotFound: This device is no longer available.
          FlashingFailure: The binary could not be flashed.
        """

    @abc.abstractmethod
    def serial_port(self) -> str:
        """Returns the name of the com port this device is enumerated on.

        Raises:
          DeviceNotFound: This device is no longer available.
        """

    @abc.abstractmethod
    def baud_rate(self) -> int:
        """Returns the baud rate to use when connecting to this device.

        Raises:
          DeviceNotFound: This device is no longer available.
        """


def _log_subprocess_output(level, output: bytes, logger: logging.Logger):
    """Logs subprocess output line-by-line."""

    lines = output.decode('utf-8', errors='replace').splitlines()
    for line in lines:
        logger.log(level, line)


def trigger_test_run(
    port: str,
    baud_rate: int,
    test_timeout: float,
    trigger_data: bytes = DEFAULT_TEST_START_CHARACTER,
) -> bytes:
    """Triggers a test run, and returns captured test results."""

    serial_data = bytearray()
    device = serial.Serial(baudrate=baud_rate, port=port, timeout=test_timeout)
    if not device.is_open:
        raise TestingFailure('Failed to open device')

    # Flush input buffer and trigger the test start.
    device.reset_input_buffer()
    device.write(trigger_data)

    # Block and wait for the first byte.
    serial_data += device.read()
    if not serial_data:
        raise TestingFailure('Device not producing output')

    # Read with a reasonable timeout until we stop getting characters.
    while True:
        bytes_read = device.readline()
        if not bytes_read:
            break
        serial_data += bytes_read
        if serial_data.rfind(_TESTS_DONE_STRING) != -1:
            # Set to much more aggressive timeout since the last one or two
            # lines should print out immediately. (one line if all fails or all
            # passes, two lines if mixed.)
            device.timeout = 0.01

    # Remove carriage returns.
    serial_data = serial_data.replace(b"\r", b"")

    # Try to trim captured results to only contain most recent test run.
    test_start_index = serial_data.rfind(_TESTS_STARTING_STRING)
    return (
        serial_data
        if test_start_index == -1
        else serial_data[test_start_index:]
    )


def handle_test_results(
    test_output: bytes, logger: logging.Logger = _LOG
) -> None:
    """Parses test output to determine whether tests passed or failed.

    Raises:
      TestingFailure if any tests fail or if test results are incomplete.
    """

    if test_output.find(_TESTS_STARTING_STRING) == -1:
        raise TestingFailure('Failed to find test start')

    if test_output.rfind(_TESTS_DONE_STRING) == -1:
        _log_subprocess_output(logging.INFO, test_output, logger)
        raise TestingFailure('Tests did not complete')

    if test_output.rfind(_TEST_FAILURE_STRING) != -1:
        _log_subprocess_output(logging.INFO, test_output, logger)
        raise TestingFailure('Test suite had one or more failures')

    _log_subprocess_output(logging.DEBUG, test_output, logger)

    logger.info('Test passed!')


def run_device_test(
    device: SerialTestingDevice,
    binary: Path,
    test_timeout: float,
    logger: logging.Logger = _LOG,
) -> bool:
    """Runs tests on a device.

    When a unit test run fails, results will be logged as an error.

    Args:
      device: The device to run tests on.
      binary: The binary containing tests that will be flashed to the device.
      test_timeout: If the device stops producing output longer than this
        timeout, the test will be considered stuck and the test will be aborted.

    Returns:
      True if all tests passed.
    """

    logger.info('Flashing binary to device')
    device.load_binary(binary)
    try:
        logger.info('Running test')
        test_output = trigger_test_run(
            device.serial_port(), device.baud_rate(), test_timeout
        )
        if test_output:
            handle_test_results(test_output, logger)
    except TestingFailure as err:
        logger.error(err)
        return False

    return True