aboutsummaryrefslogtreecommitdiff
path: root/pw_transfer/integration_test/python_client.py
blob: 2474cf1215f2144f8984e08d0e4f866fcc6aed5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Python client for pw_transfer integration test."""

import logging
import socket
import sys

from google.protobuf import text_format
from pw_hdlc.rpc import HdlcRpcClient, default_channels, SocketReader
from pw_status import Status
import pw_transfer
from pigweed.pw_transfer import transfer_pb2
from pigweed.pw_transfer.integration_test import config_pb2

# Logger for this client; it emits DEBUG and above to stdout.
_LOG = logging.getLogger('pw_transfer_integration_test_python_client')
# Use the documented setter rather than assigning the `level` attribute, so
# the level is validated and matches how the 'pw_transfer' logger is set up.
_LOG.setLevel(logging.DEBUG)
_LOG.addHandler(logging.StreamHandler(sys.stdout))

# Host that the client connects to; the port is supplied on the command line.
HOSTNAME: str = "localhost"


def _perform_transfer_action(
    action: config_pb2.TransferAction, transfer_manager: pw_transfer.Manager
) -> bool:
    """Performs the transfer action and returns True on success.

    Args:
        action: Transfer to run. Its transfer_type selects a write to or a
            read from the server, file_path names the local file that is read
            (for writes) or written (for reads), and expected_status is the
            status with which the transfer is allowed to fail.
        transfer_manager: Manager used to run the transfer.

    Returns:
        True if the transfer completed, or failed with exactly the status in
        action.expected_status. False if the local file could not be accessed,
        the transfer failed in any other way, or the transfer type is unknown.
    """
    # NOTE(review): a transfer that completes without error is reported as a
    # success even if action.expected_status is not OK -- confirm this is
    # intended.
    protocol_version = pw_transfer.ProtocolVersion(int(action.protocol_version))

    # Default to the latest protocol version if none is specified.
    if protocol_version == pw_transfer.ProtocolVersion.UNKNOWN:
        protocol_version = pw_transfer.ProtocolVersion.LATEST

    if (
        action.transfer_type
        == config_pb2.TransferAction.TransferType.WRITE_TO_SERVER
    ):
        try:
            with open(action.file_path, 'rb') as f:
                data = f.read()
        except (OSError, ValueError):
            # OSError covers missing/unreadable files; ValueError covers
            # malformed paths (e.g. an embedded NUL byte).
            _LOG.critical(
                "Failed to read input file '%s'",
                action.file_path,
                exc_info=True,
            )
            return False

        try:
            transfer_manager.write(
                action.resource_id,
                data,
                protocol_version=protocol_version,
            )
        except pw_transfer.client.Error as e:
            # A transfer error is acceptable only if it is the expected one.
            if e.status != Status(action.expected_status):
                _LOG.exception(
                    "Unexpected error encountered during write transfer"
                )
                return False
        except Exception:  # pylint: disable=broad-except
            # Catch-all so any other failure is reported as a failed action,
            # while still letting KeyboardInterrupt/SystemExit propagate.
            _LOG.exception("Transfer (write to server) failed")
            return False
    elif (
        action.transfer_type
        == config_pb2.TransferAction.TransferType.READ_FROM_SERVER
    ):
        try:
            data = transfer_manager.read(
                action.resource_id,
                protocol_version=protocol_version,
            )
        except pw_transfer.client.Error as e:
            # A transfer error is acceptable only if it is the expected one.
            if e.status != Status(action.expected_status):
                _LOG.exception(
                    "Unexpected error encountered during read transfer"
                )
                return False
            # The expected error occurred, so there is no data to write out.
            return True
        except Exception:  # pylint: disable=broad-except
            # Catch-all so any other failure is reported as a failed action,
            # while still letting KeyboardInterrupt/SystemExit propagate.
            _LOG.exception("Transfer (read from server) failed")
            return False

        try:
            with open(action.file_path, 'wb') as f:
                f.write(data)
        except (OSError, ValueError):
            # OSError covers unwritable locations; ValueError covers
            # malformed paths (e.g. an embedded NUL byte).
            _LOG.critical(
                "Failed to write output file '%s'",
                action.file_path,
                exc_info=True,
            )
            return False
    else:
        _LOG.critical("Unknown transfer type: %d", action.transfer_type)
        return False
    return True


def _main() -> int:
    """Runs the Python transfer client for the integration test.

    Reads the server port from the single command-line argument and a
    text-format ClientConfig proto from stdin, connects to the server over TCP
    at HOSTNAME, and performs each configured transfer action in order,
    stopping at the first one that fails.

    Returns:
        0 if every transfer action succeeded; 1 on a usage error, an invalid
        port, an unparseable config, a connection failure, or a failed
        transfer action.
    """
    if len(sys.argv) != 2:
        _LOG.critical("Usage: PORT")
        return 1

    # The port is passed via the command line.
    try:
        port = int(sys.argv[1])
    except ValueError:
        _LOG.critical("Invalid port specified.")
        return 1

    # Reject out-of-range ports up front: socket.connect() raises
    # OverflowError (not OSError) for them.
    if not 0 <= port <= 65535:
        _LOG.critical("Invalid port specified.")
        return 1

    # Load the config from stdin.
    try:
        text_config = sys.stdin.buffer.read()
        config = text_format.Parse(text_config, config_pb2.ClientConfig())
    except Exception as e:  # pylint: disable=broad-except
        _LOG.critical("Failed to parse config file from stdin: %s", e)
        return 1

    # Open a connection to the server.
    try:
        rpc_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            rpc_socket.connect((HOSTNAME, port))
        except OSError:
            # Don't leak the socket if the connection attempt fails.
            rpc_socket.close()
            raise
    except OSError:
        _LOG.critical("Failed to connect to server at %s:%d", HOSTNAME, port)
        return 1

    # Initialize an RPC client using a socket reader and set up the
    # pw_transfer manager.
    reader = SocketReader(rpc_socket, 4096)
    # The socket is entered first so that it is closed last, after the reader
    # and the RPC client have been torn down.
    with rpc_socket, reader:
        rpc_client = HdlcRpcClient(
            reader,
            [transfer_pb2],
            default_channels(rpc_socket.sendall),
            lambda data: _LOG.info("%s", str(data)),
        )
        with rpc_client:
            transfer_service = rpc_client.rpcs().pw.transfer.Transfer
            transfer_manager = pw_transfer.Manager(
                transfer_service,
                # The config stores timeouts in milliseconds; the manager
                # takes seconds.
                default_response_timeout_s=config.chunk_timeout_ms / 1000,
                initial_response_timeout_s=config.initial_chunk_timeout_ms
                / 1000,
                max_retries=config.max_retries,
                max_lifetime_retries=config.max_lifetime_retries,
                default_protocol_version=pw_transfer.ProtocolVersion.LATEST,
            )

            # Send the pw_transfer library's own logs to stdout as well.
            transfer_logger = logging.getLogger('pw_transfer')
            transfer_logger.setLevel(logging.DEBUG)
            transfer_logger.addHandler(logging.StreamHandler(sys.stdout))

            # Perform the requested transfer actions.
            for action in config.transfer_actions:
                if not _perform_transfer_action(action, transfer_manager):
                    return 1

    _LOG.info("All transfers completed successfully")
    return 0


if __name__ == '__main__':
    # Exit the process with the status code returned by _main().
    raise SystemExit(_main())