diff options
Diffstat (limited to 'internal/lib/gcompute_client.py')
-rwxr-xr-x | internal/lib/gcompute_client.py | 173 |
1 files changed, 131 insertions, 42 deletions
diff --git a/internal/lib/gcompute_client.py b/internal/lib/gcompute_client.py index 1f44de05..759fb651 100755 --- a/internal/lib/gcompute_client.py +++ b/internal/lib/gcompute_client.py @@ -39,6 +39,11 @@ from acloud.internal.lib import utils logger = logging.getLogger(__name__) _MAX_RETRIES_ON_FINGERPRINT_CONFLICT = 10 +_METADATA_KEY = "key" +_METADATA_KEY_VALUE = "value" +_SSH_KEYS_NAME = "sshKeys" +_ITEMS = "items" +_METADATA = "metadata" BASE_DISK_ARGS = { "type": "PERSISTENT", @@ -1078,7 +1083,8 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): image_project=None, gpu=None, extra_disk_name=None, - labels=None): + labels=None, + extra_scopes=None): """Create a gce instance with a gce image. Args: @@ -1099,11 +1105,18 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): https://cloud.google.com/compute/docs/gpus/add-gpus extra_disk_name: String,the name of the extra disk to attach. labels: Dict, will be added to the instance's labels. + extra_scopes: A list of extra scopes to be provided to the instance. """ disk_args = (disk_args or self._GetDiskArgs(instance, image_name, image_project)) if extra_disk_name: disk_args.extend(self._GetExtraDiskArgs(extra_disk_name, zone)) + + scopes = [] + scopes.extend(self.DEFAULT_INSTANCE_SCOPE) + if extra_scopes: + scopes.extend(extra_scopes) + body = { "machineType": self.GetMachineType(machine_type, zone)["selfLink"], "name": instance, @@ -1111,10 +1124,11 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): "disks": disk_args, "serviceAccounts": [{ "email": "default", - "scopes": self.DEFAULT_INSTANCE_SCOPE + "scopes": scopes, }], } + if labels is not None: body["labels"] = labels if gpu: @@ -1127,10 +1141,10 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): body["scheduling"] = {"onHostMaintenance": "terminate"} if metadata: metadata_list = [{ - "key": key, - "value": val + _METADATA_KEY: key, + _METADATA_KEY_VALUE: val } for key, val in metadata.iteritems()] - body["metadata"] = {"items": metadata_list} + body[_METADATA] = {_ITEMS: metadata_list} logger.info("Creating instance: project %s, zone %s, body:%s", self._project, zone, body) api = self.service.instances().insert( @@ -1364,63 +1378,53 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): external_ip = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"] return IP(internal=internal_ip, external=external_ip) - def SetCommonInstanceMetadata(self, body): - """Set project-wide metadata. + @utils.TimeExecute(function_description="Updating instance metadata: ") + def SetInstanceMetadata(self, zone, instance, body): + """Set instance metadata. Args: - body: Metadata body. + zone: String, name of zone. + instance: String, representing instance name. + body: Dict, Metadata body. metdata is in the following format. { "kind": "compute#metadata", "fingerprint": "a-23icsyx4E=", "items": [ { - "key": "google-compute-default-region", - "value": "us-central1" + "key": "sshKeys", + "value": "key" }, ... ] } """ - api = self.service.projects().setCommonInstanceMetadata( - project=self._project, body=body) + api = self.service.instances().setMetadata( + project=self._project, zone=zone, instance=instance, body=body) operation = self.Execute(api) - self.WaitOnOperation(operation, operation_scope=OperationScope.GLOBAL) + self.WaitOnOperation( + operation, operation_scope=OperationScope.ZONE, scope_name=zone) - def AddSshRsa(self, user, ssh_rsa_path): - """Add the public rsa key to the project's metadata. + def AddSshRsaInstanceMetadata(self, zone, user, ssh_rsa_path, instance): + """Add the public rsa key to the instance's metadata. - Compute engine instances that are created after will - by default contain the key. + Confirm that the instance has this public key in the instance's + metadata, if not we will add this public key. Args: - user: the name of the user which the key belongs to. - ssh_rsa_path: The absolute path to public rsa key. + zone: String, name of zone. + user: String, name of the user which the key belongs to. + ssh_rsa_path: String, The absolute path to public rsa key. + instance: String, representing instance name. """ - if not os.path.exists(ssh_rsa_path): - raise errors.DriverError( - "RSA file %s does not exist." % ssh_rsa_path) - - logger.info("Adding ssh rsa key from %s to project %s for user: %s", - ssh_rsa_path, self._project, user) - project = self.GetProject() - with open(ssh_rsa_path) as f: - rsa = f.read() - rsa = rsa.strip() if rsa else rsa - utils.VerifyRsaPubKey(rsa) - metadata = project["commonInstanceMetadata"] - for item in metadata.setdefault("items", []): - if item["key"] == "sshKeys": - sshkey_item = item - break - else: - sshkey_item = {"key": "sshKeys", "value": ""} - metadata["items"].append(sshkey_item) - + ssh_rsa_path = os.path.expanduser(ssh_rsa_path) + rsa = GetRsaKey(ssh_rsa_path) entry = "%s:%s" % (user, rsa) logger.debug("New RSA entry: %s", entry) - sshkey_item["value"] = "\n".join([sshkey_item["value"].strip(), - entry]).strip() - self.SetCommonInstanceMetadata(metadata) + + gce_instance = self.GetInstance(instance, zone) + metadata = gce_instance.get(_METADATA) + if RsaNotInMetadata(metadata, entry): + self.UpdateRsaInMetadata(zone, instance, metadata, entry) def CheckAccess(self): """Check if the user has read access to the cloud project. @@ -1443,3 +1447,88 @@ class ComputeClient(base_cloud_client.BaseCloudApiClient): return False raise return True + + def UpdateRsaInMetadata(self, zone, instance, metadata, entry): + """Update ssh public key to sshKeys's value in this metadata. + + Args: + zone: String, name of zone. + instance: String, representing instance name. + metadata: Dict, maps a metadata name to its value. + entry: String, ssh public key. + """ + ssh_key_item = GetSshKeyFromMetadata(metadata) + if ssh_key_item: + # The ssh key exists in the metadata so update the reference to it + # in the metadata. There may not be an actual ssh key value so + # that's why we filter for None to avoid an empty line in front. + ssh_key_item[_METADATA_KEY_VALUE] = "\n".join( + filter(None, [ssh_key_item[_METADATA_KEY_VALUE], entry])) + else: + # Since there is no ssh key item in the metadata, we need to add it in. + ssh_key_item = {_METADATA_KEY: _SSH_KEYS_NAME, + _METADATA_KEY_VALUE: entry} + metadata[_ITEMS].append(ssh_key_item) + utils.PrintColorString( + "Ssh public key doesn't exist in the instance(%s), adding it." + % instance, utils.TextColors.WARNING) + self.SetInstanceMetadata(zone, instance, metadata) + + +def RsaNotInMetadata(metadata, entry): + """Check ssh public key exist in sshKeys's value. + + Args: + metadata: Dict, maps a metadata name to its value. + entry: String, ssh public key. + + Returns: + Boolean. True if ssh public key doesn't exist in metadata. + """ + for item in metadata.setdefault(_ITEMS, []): + if item[_METADATA_KEY] == _SSH_KEYS_NAME: + if entry in item[_METADATA_KEY_VALUE]: + return False + return True + + +def GetSshKeyFromMetadata(metadata): + """Get ssh key item from metadata. + + Args: + metadata: Dict, maps a metadata name to its value. + + Returns: + Dict of ssk_key_item in metadata, None if can't find the ssh key item + in metadata. + """ + for item in metadata.setdefault(_ITEMS, []): + if item.get(_METADATA_KEY, '') == _SSH_KEYS_NAME: + return item + return None + + +def GetRsaKey(ssh_rsa_path): + """Get rsa key from rsa path. + + Args: + ssh_rsa_path: String, The absolute path to public rsa key. + + Returns: + String, rsa key. + + Raises: + errors.DriverError: RSA file does not exist. + """ + ssh_rsa_path = os.path.expanduser(ssh_rsa_path) + if not os.path.exists(ssh_rsa_path): + raise errors.DriverError( + "RSA file %s does not exist." % ssh_rsa_path) + + with open(ssh_rsa_path) as f: + rsa = f.read() + # The space must be removed here for string processing, + # if it is not string, it doesn't have a strip function. + rsa = rsa.strip() if rsa else rsa + utils.VerifyRsaPubKey(rsa) + return rsa |