aboutsummaryrefslogtreecommitdiff
path: root/public/actions/common_operations.py
blob: 2b64ec5474f46dca9d6a95982a36e513e016bf73 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#!/usr/bin/env python
#
# Copyright 2018 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common operations between managing GCE and Cuttlefish devices.

This module provides the common operations between managing GCE (device_driver)
and Cuttlefish (create_cuttlefish_action) devices. Should not be called
directly.
"""

import logging
import os

from acloud import errors
from acloud.public import avd
from acloud.public import report
from acloud.internal import constants
from acloud.internal.lib import utils
from acloud.internal.lib.adb_tools import AdbTools


logger = logging.getLogger(__name__)
_GCE_QUOTA_ERROR_KEYWORDS = [
    "Quota exceeded for quota",
    "ZONE_RESOURCE_POOL_EXHAUSTED",
    "ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS"]
_DICT_ERROR_TYPE = {
    constants.STAGE_INIT: constants.ACLOUD_INIT_ERROR,
    constants.STAGE_GCE: constants.ACLOUD_CREATE_GCE_ERROR,
    constants.STAGE_SSH_CONNECT: constants.ACLOUD_SSH_CONNECT_ERROR,
    constants.STAGE_ARTIFACT: constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR,
    constants.STAGE_BOOT_UP: constants.ACLOUD_BOOT_UP_ERROR,
}


def CreateSshKeyPairIfNecessary(cfg):
    """Create ssh key pair if necessary.

    Args:
        cfg: An Acloudconfig instance.

    Raises:
        error.DriverError: If it falls into an unexpected condition.
    """
    if not cfg.ssh_public_key_path:
        logger.warning(
            "ssh_public_key_path is not specified in acloud config. "
            "Project-wide public key will "
            "be used when creating AVD instances. "
            "Please ensure you have the correct private half of "
            "a project-wide public key if you want to ssh into the "
            "instances after creation.")
    elif cfg.ssh_public_key_path and not cfg.ssh_private_key_path:
        logger.warning(
            "Only ssh_public_key_path is specified in acloud config, "
            "but ssh_private_key_path is missing. "
            "Please ensure you have the correct private half "
            "if you want to ssh into the instances after creation.")
    elif cfg.ssh_public_key_path and cfg.ssh_private_key_path:
        utils.CreateSshKeyPairIfNotExist(cfg.ssh_private_key_path,
                                         cfg.ssh_public_key_path)
    else:
        # Should never reach here.
        raise errors.DriverError(
            "Unexpected error in CreateSshKeyPairIfNecessary")


class DevicePool:
    """A class that manages a pool of virtual devices.

    Attributes:
        devices: A list of devices in the pool.
    """

    def __init__(self, device_factory, devices=None):
        """Constructs a new DevicePool.

        Args:
            device_factory: A device factory capable of producing a goldfish or
                cuttlefish device. The device factory must expose an attribute with
                the credentials that can be used to retrieve information from the
                constructed device.
            devices: List of devices managed by this pool.
        """
        self._devices = devices or []
        self._device_factory = device_factory
        self._compute_client = device_factory.GetComputeClient()

    def CreateDevices(self, num):
        """Creates |num| devices for given build_target and build_id.

        Args:
            num: Number of devices to create.
        """
        # Create host instances for cuttlefish/goldfish device.
        # Currently one instance supports only 1 device.
        for _ in range(num):
            instance = self._device_factory.CreateInstance()
            ip = self._compute_client.GetInstanceIP(instance)
            time_info = self._compute_client.execution_time if hasattr(
                self._compute_client, "execution_time") else {}
            stage = self._compute_client.stage if hasattr(
                self._compute_client, "stage") else 0
            openwrt = self._compute_client.openwrt if hasattr(
                self._compute_client, "openwrt") else False
            self.devices.append(
                avd.AndroidVirtualDevice(ip=ip, instance_name=instance,
                                         time_info=time_info, stage=stage,
                                         openwrt=openwrt))

    @utils.TimeExecute(function_description="Waiting for AVD(s) to boot up",
                       result_evaluator=utils.BootEvaluator)
    def WaitForBoot(self, boot_timeout_secs):
        """Waits for all devices to boot up.

        Args:
            boot_timeout_secs: Integer, the maximum time in seconds used to
                               wait for the AVD to boot.

        Returns:
            A dictionary that contains all the failures.
            The key is the name of the instance that fails to boot,
            and the value is an errors.DeviceBootError object.
        """
        failures = {}
        for device in self._devices:
            try:
                self._compute_client.WaitForBoot(device.instance_name, boot_timeout_secs)
            except errors.DeviceBootError as e:
                failures[device.instance_name] = e
        return failures

    def UpdateReport(self, reporter):
        """Update report from compute client.

        Args:
            reporter: Report object.
        """
        reporter.UpdateData(self._compute_client.dict_report)

    def CollectSerialPortLogs(self, output_file,
                              port=constants.DEFAULT_SERIAL_PORT):
        """Tar the instance serial logs into specified output_file.

        Args:
            output_file: String, the output tar file path
            port: The serial port number to be collected
        """
        # For emulator, the serial log is the virtual host serial log.
        # For GCE AVD device, the serial log is the AVD device serial log.
        with utils.TempDir() as tempdir:
            src_dict = {}
            for device in self._devices:
                logger.info("Store instance %s serial port %s output to %s",
                            device.instance_name, port, output_file)
                serial_log = self._compute_client.GetSerialPortOutput(
                    instance=device.instance_name, port=port)
                file_name = "%s_serial_%s.log" % (device.instance_name, port)
                file_path = os.path.join(tempdir, file_name)
                src_dict[file_path] = file_name
                with open(file_path, "w") as f:
                    f.write(serial_log.encode("utf-8"))
            utils.MakeTarFile(src_dict, output_file)

    def SetDeviceBuildInfo(self):
        """Add devices build info."""
        for device in self._devices:
            device.build_info = self._device_factory.GetBuildInfoDict()

    @property
    def devices(self):
        """Returns a list of devices in the pool.

        Returns:
            A list of devices in the pool.
        """
        return self._devices

def _GetErrorType(error):
    """Get proper error type from the exception error.

    Args:
        error: errors object.

    Returns:
        String of error type. e.g. "ACLOUD_BOOT_UP_ERROR".
    """
    if isinstance(error, errors.CheckGCEZonesQuotaError):
        return constants.GCE_QUOTA_ERROR
    if isinstance(error, errors.DownloadArtifactError):
        return constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR
    if isinstance(error, errors.DeviceConnectionError):
        return constants.ACLOUD_SSH_CONNECT_ERROR
    for keyword in _GCE_QUOTA_ERROR_KEYWORDS:
        if keyword in str(error):
            return constants.GCE_QUOTA_ERROR
    return constants.ACLOUD_UNKNOWN_ERROR

def _GetAdbPort(avd_type, base_instance_num):
    """Get Adb port according to avd_type and device offset.

    Args:
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        base_instance_num: int, device offset.

    Returns:
        int, adb port.
    """
    if avd_type in utils.AVD_PORT_DICT:
        return utils.AVD_PORT_DICT[avd_type].adb_port + base_instance_num - 1
    return None

# pylint: disable=too-many-locals,unused-argument,too-many-branches
def CreateDevices(command, cfg, device_factory, num, avd_type,
                  report_internal_ip=False, autoconnect=False,
                  serial_log_file=None, client_adb_port=None,
                  boot_timeout_secs=None, unlock_screen=False,
                  wait_for_boot=True, connect_webrtc=False,
                  ssh_private_key_path=None,
                  ssh_user=constants.GCE_USER):
    """Create a set of devices using the given factory.

    Main jobs in create devices.
        1. Create GCE instance: Launch instance in GCP(Google Cloud Platform).
        2. Starting up AVD: Wait device boot up.

    Args:
        command: The name of the command, used for reporting.
        cfg: An AcloudConfig instance.
        device_factory: A factory capable of producing a single device.
        num: The number of devices to create.
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        report_internal_ip: Boolean to report the internal ip instead of
                            external ip.
        serial_log_file: String, the file path to tar the serial logs.
        autoconnect: Boolean, whether to auto connect to device.
        client_adb_port: Integer, Specify port for adb forwarding.
        boot_timeout_secs: Integer, boot timeout secs.
        unlock_screen: Boolean, whether to unlock screen after invoke vnc client.
        wait_for_boot: Boolean, True to check serial log include boot up
                       message.
        connect_webrtc: Boolean, whether to auto connect webrtc to device.
        ssh_private_key_path: String, the private key for SSH tunneling.
        ssh_user: String, the user name for SSH tunneling.

    Raises:
        errors: Create instance fail.

    Returns:
        A Report instance.
    """
    reporter = report.Report(command=command)
    try:
        CreateSshKeyPairIfNecessary(cfg)
        device_pool = DevicePool(device_factory)
        device_pool.CreateDevices(num)
        device_pool.SetDeviceBuildInfo()
        if wait_for_boot:
            failures = device_pool.WaitForBoot(boot_timeout_secs)
        else:
            failures = device_factory.GetFailures()

        if failures:
            reporter.SetStatus(report.Status.BOOT_FAIL)
        else:
            reporter.SetStatus(report.Status.SUCCESS)

        # Collect logs
        logs = device_factory.GetLogs()
        if serial_log_file:
            device_pool.CollectSerialPortLogs(
                serial_log_file, port=constants.DEFAULT_SERIAL_PORT)

        device_pool.UpdateReport(reporter)
        # Write result to report.
        for device in device_pool.devices:
            ip = (device.ip.internal if report_internal_ip
                  else device.ip.external)
            base_instance_num = 1
            if constants.BASE_INSTANCE_NUM in device_pool._compute_client.dict_report:
                base_instance_num = device_pool._compute_client.dict_report[constants.BASE_INSTANCE_NUM]
            adb_port = _GetAdbPort(
                avd_type,
                base_instance_num
            )
            device_dict = {
                "ip": ip + (":" + str(adb_port) if adb_port else ""),
                "instance_name": device.instance_name
            }
            if device.build_info:
                device_dict.update(device.build_info)
            if device.time_info:
                device_dict.update(device.time_info)
            if device.openwrt:
                device_dict.update(device_factory.GetOpenWrtInfoDict())
            if autoconnect and reporter.status == report.Status.SUCCESS:
                forwarded_ports = utils.AutoConnect(
                    ip_addr=ip,
                    rsa_key_file=(ssh_private_key_path or
                                  cfg.ssh_private_key_path),
                    target_vnc_port=utils.AVD_PORT_DICT[avd_type].vnc_port,
                    target_adb_port=adb_port,
                    ssh_user=ssh_user,
                    client_adb_port=client_adb_port,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
                device_dict[constants.VNC_PORT] = forwarded_ports.vnc_port
                device_dict[constants.ADB_PORT] = forwarded_ports.adb_port
                device_dict[constants.DEVICE_SERIAL] = (
                    constants.REMOTE_INSTANCE_ADB_SERIAL %
                    forwarded_ports.adb_port)
                if unlock_screen:
                    AdbTools(forwarded_ports.adb_port).AutoUnlockScreen()
            if connect_webrtc and reporter.status == report.Status.SUCCESS:
                webrtc_local_port = utils.PickFreePort()
                device_dict[constants.WEBRTC_PORT] = webrtc_local_port
                utils.EstablishWebRTCSshTunnel(
                    ip_addr=ip,
                    webrtc_local_port=webrtc_local_port,
                    rsa_key_file=(ssh_private_key_path or
                                  cfg.ssh_private_key_path),
                    ssh_user=ssh_user,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
            if device.instance_name in logs:
                device_dict[constants.LOGS] = logs[device.instance_name]
            if device.instance_name in failures:
                reporter.SetErrorType(constants.ACLOUD_BOOT_UP_ERROR)
                if device.stage:
                    reporter.SetErrorType(_DICT_ERROR_TYPE[device.stage])
                reporter.AddData(key="devices_failing_boot", value=device_dict)
                reporter.AddError(str(failures[device.instance_name]))
            else:
                reporter.AddData(key="devices", value=device_dict)
    except (errors.DriverError, errors.CheckGCEZonesQuotaError) as e:
        reporter.SetErrorType(_GetErrorType(e))
        reporter.AddError(str(e))
        reporter.SetStatus(report.Status.FAIL)
    return reporter