aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorjuerg <juerg@google.com>2023-03-29 00:44:34 -0700
committerCopybara-Service <copybara-worker@google.com>2023-03-29 00:45:51 -0700
commit6c2a240a3c746966d060b5e9f22ed19057b0ec87 (patch)
treebb7ba83bbbe4877bfd4b58c56b6b7f4558192916 /python
parent75f1bab3e17b74d31e63969ecaf789749f475339 (diff)
downloadtink-6c2a240a3c746966d060b5e9f22ed19057b0ec87.tar.gz
Let AWS KMS integration in Python use boto3 instead of wrapping the C++ integration.
This should not change the behavior of the current API, it implements the same as cc/integration/awskms/aws_kms_client.cc in Python. PiperOrigin-RevId: 520254665
Diffstat (limited to 'python')
-rw-r--r--python/requirements.in1
-rw-r--r--python/requirements.txt36
-rw-r--r--python/tink/integration/awskms/BUILD.bazel4
-rw-r--r--python/tink/integration/awskms/__init__.py2
-rw-r--r--python/tink/integration/awskms/_aws_kms_client.py127
-rw-r--r--python/tink/integration/awskms/_aws_kms_client_test.py57
6 files changed, 208 insertions, 19 deletions
diff --git a/python/requirements.in b/python/requirements.in
index 1c826af47..e7aa75175 100644
--- a/python/requirements.in
+++ b/python/requirements.in
@@ -1,2 +1,3 @@
absl-py==1.3.0
protobuf==4.21.9
+boto3==1.26.89
diff --git a/python/requirements.txt b/python/requirements.txt
index 1c3b47bbe..98a1936e0 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,6 +1,6 @@
#
-# This file is autogenerated by pip-compile with python 3.10
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
#
# pip-compile --generate-hashes --output-file=requirements.txt requirements.in
#
@@ -8,6 +8,22 @@ absl-py==1.3.0 \
--hash=sha256:34995df9bd7a09b3b8749e230408f5a2a2dd7a68a0d33c12a3d0cb15a041a507 \
--hash=sha256:463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248
# via -r requirements.in
+boto3==1.26.89 \
+ --hash=sha256:09929b24aaec4951e435d53d31f800e2ca52244af049dc11e5385ce062e106e9 \
+ --hash=sha256:e819812f16fab46fadf9b2853a46aaa126e108e7f038502dde555ebbbfc80133
+ # via -r requirements.in
+botocore==1.29.89 \
+ --hash=sha256:ac8da651f73a9d5759cf5d80beba68deda407e56aaaeb10d249fd557459f3b56 \
+ --hash=sha256:b757e59feca82ac62934f658918133116b4535cf66f1d72ff4935fa24e522527
+ # via
+ # boto3
+ # s3transfer
+jmespath==1.0.1 \
+ --hash=sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980 \
+ --hash=sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe
+ # via
+ # boto3
+ # botocore
protobuf==4.21.9 \
--hash=sha256:2c9c2ed7466ad565f18668aa4731c535511c5d9a40c6da39524bccf43e441719 \
--hash=sha256:48e2cd6b88c6ed3d5877a3ea40df79d08374088e89bedc32557348848dff250b \
@@ -24,3 +40,19 @@ protobuf==4.21.9 \
--hash=sha256:e575c57dc8b5b2b2caa436c16d44ef6981f2235eb7179bfc847557886376d740 \
--hash=sha256:f9eae277dd240ae19bb06ff4e2346e771252b0e619421965504bd1b1bba7c5fa
# via -r requirements.in
+python-dateutil==2.8.2 \
+ --hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \
+ --hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9
+ # via botocore
+s3transfer==0.6.0 \
+ --hash=sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd \
+ --hash=sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947
+ # via boto3
+six==1.16.0 \
+ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \
+ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254
+ # via python-dateutil
+urllib3==1.26.15 \
+ --hash=sha256:8a388717b9476f934a21484e8c8e61875ab60644d29b9b39e11e4b9dc1c6b305 \
+ --hash=sha256:aa751d169e23c7479ce47a0cb0da579e3ede798f994f5816a74e4f4500dcea42
+ # via botocore
diff --git a/python/tink/integration/awskms/BUILD.bazel b/python/tink/integration/awskms/BUILD.bazel
index 651315741..455136812 100644
--- a/python/tink/integration/awskms/BUILD.bazel
+++ b/python/tink/integration/awskms/BUILD.bazel
@@ -18,10 +18,11 @@ py_library(
srcs = ["_aws_kms_client.py"],
srcs_version = "PY3",
deps = [
+ "//tink:tink_python",
"//tink/aead",
"//tink/aead:_kms_aead_key_manager",
- "//tink/cc/pybind:tink_bindings",
"//tink/core",
+ requirement("boto3"),
],
)
@@ -35,6 +36,7 @@ py_test(
srcs_version = "PY3",
deps = [
":awskms",
+ ":_aws_kms_client",
"//tink:tink_python",
"//tink/testing:helper",
requirement("absl-py"),
diff --git a/python/tink/integration/awskms/__init__.py b/python/tink/integration/awskms/__init__.py
index a678c35f3..c86dc9c1a 100644
--- a/python/tink/integration/awskms/__init__.py
+++ b/python/tink/integration/awskms/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""StreamingAead package."""
+"""AWS KMS integration package."""
from tink.integration.awskms import _aws_kms_client
diff --git a/python/tink/integration/awskms/_aws_kms_client.py b/python/tink/integration/awskms/_aws_kms_client.py
index 86ebc433a..822e1b4b3 100644
--- a/python/tink/integration/awskms/_aws_kms_client.py
+++ b/python/tink/integration/awskms/_aws_kms_client.py
@@ -13,13 +13,95 @@
# limitations under the License.
"""A client for AWS KMS."""
+import binascii
+import configparser
import re
+from typing import Tuple, Any, Dict
+import boto3
+from botocore import exceptions
+
+import tink
from tink import aead
-from tink import core
from tink.aead import _kms_aead_key_manager
-from tink.cc.pybind import tink_bindings
+
+
+AWS_KEYURI_PREFIX = 'aws-kms://'
+
+
+def _encryption_context(associated_data: bytes) -> Dict[str, str]:
+ if associated_data:
+ hex_associated_data = binascii.hexlify(associated_data).decode('utf-8')
+ return {'associatedData': hex_associated_data}
+ else:
+ return dict()
+
+
+class _AwsKmsAead(aead.Aead):
+ """Implements the Aead interface for AWS KMS."""
+
+ def __init__(self, client: Any, key_arn: str) -> None:
+ self.client = client
+ self.key_arn = key_arn
+
+ def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
+ try:
+ response = self.client.encrypt(
+ KeyId=self.key_arn,
+ Plaintext=plaintext,
+ EncryptionContext=_encryption_context(associated_data),
+ )
+ return response['CiphertextBlob']
+ except exceptions.ClientError as e:
+ raise tink.TinkError(e)
+
+ def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
+ try:
+ response = self.client.decrypt(
+ KeyId=self.key_arn,
+ CiphertextBlob=ciphertext,
+ EncryptionContext=_encryption_context(associated_data),
+ )
+ if response['KeyId'] != self.key_arn:
+ raise tink.TinkError(
+ 'invalid key id: got %s, want %s'
+ % (self.key_arn, response['KeyId'])
+ )
+ return response['Plaintext']
+ except exceptions.ClientError as e:
+ raise tink.TinkError(e)
+
+
+def _key_uri_to_key_arn(key_uri: str) -> str:
+ if not key_uri.startswith(AWS_KEYURI_PREFIX):
+ raise tink.TinkError('invalid key URI')
+ return key_uri[len(AWS_KEYURI_PREFIX) :]
+
+
+def _parse_config(config_path: str) -> Tuple[str, str]:
+ """Returns ('aws_access_key_id', 'aws_secret_access_key') from a config."""
+ config = configparser.ConfigParser()
+ config.read(config_path)
+ if 'default' not in config:
+ raise ValueError('invalid config: default not found')
+ default = config['default']
+ if 'aws_access_key_id' not in default:
+ raise ValueError('invalid config: aws_access_key_id not found')
+ aws_access_key_id = default['aws_access_key_id']
+ if 'aws_secret_access_key' not in default:
+ raise ValueError('invalid config: aws_secret_access_key not found')
+ aws_secret_access_key = default['aws_secret_access_key']
+ return (aws_access_key_id, aws_secret_access_key)
+
+
+def _get_region_from_key_arn(key_arn: str) -> str:
+ # An AWS key ARN is of the form
+ # arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab.
+ key_arn_parts = key_arn.split(':')
+ if len(key_arn_parts) < 6:
+ raise tink.TinkError('invalid key id')
+ return key_arn_parts[3]
class AwsKmsClient(_kms_aead_key_manager.KmsClient):
@@ -44,19 +126,19 @@ class AwsKmsClient(_kms_aead_key_manager.KmsClient):
ValueError: If the path or filename of the credentials is invalid.
TinkError: If the key uri is not valid.
"""
-
- match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri)
if not key_uri:
- self._key_uri = ''
- elif match:
- self._key_uri = key_uri
+ self._key_arn = None
else:
- raise core.TinkError
-
- self.cc_client = tink_bindings.AwsKmsClient(key_uri, credentials_path)
+ match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri)
+ if not match:
+ raise tink.TinkError('invalid key URI')
+ self._key_arn = _key_uri_to_key_arn(key_uri)
+ aws_access_key_id, aws_secret_access_key = _parse_config(credentials_path)
+ self._aws_access_key_id = aws_access_key_id
+ self._aws_secret_access_key = aws_secret_access_key
def does_support(self, key_uri: str) -> bool:
- """Returns true iff this client supports KMS key specified in 'key_uri'.
+ """Returns true if this client supports KMS key specified in 'key_uri'.
Args:
key_uri: Text, URI of the key to be checked.
@@ -64,9 +146,12 @@ class AwsKmsClient(_kms_aead_key_manager.KmsClient):
Returns: A boolean value which is true if the key is supported and false
otherwise.
"""
- return self.cc_client.does_support(key_uri)
+ if not key_uri.startswith(AWS_KEYURI_PREFIX):
+ return False
+ if not self._key_arn:
+ return True
+ return _key_uri_to_key_arn(key_uri) == self._key_arn
- @core.use_tink_errors
def get_aead(self, key_uri: str) -> aead.Aead:
"""Returns an Aead-primitive backed by KMS key specified by 'key_uri'.
@@ -79,8 +164,20 @@ class AwsKmsClient(_kms_aead_key_manager.KmsClient):
Raises:
TinkError: If the key_uri is not supported.
"""
-
- return aead.AeadCcToPyWrapper(self.cc_client.get_aead(key_uri))
+ if not self.does_support(key_uri):
+ if self._key_arn:
+ raise tink.TinkError(
+ 'This client is bound to %s and cannot use key %s' %
+ (self._key_arn, key_uri))
+ raise tink.TinkError(
+ 'This client does not support key %s' % key_uri)
+ key_arn = _key_uri_to_key_arn(key_uri)
+ session = boto3.session.Session(
+ aws_access_key_id=self._aws_access_key_id,
+ aws_secret_access_key=self._aws_secret_access_key,
+ region_name=_get_region_from_key_arn(key_arn),
+ )
+ return _AwsKmsAead(session.client('kms'), key_arn)
@classmethod
def register_client(cls, key_uri, credentials_path) -> None:
diff --git a/python/tink/integration/awskms/_aws_kms_client_test.py b/python/tink/integration/awskms/_aws_kms_client_test.py
index 5a0901b8f..129f2de9e 100644
--- a/python/tink/integration/awskms/_aws_kms_client_test.py
+++ b/python/tink/integration/awskms/_aws_kms_client_test.py
@@ -15,12 +15,16 @@
import os
+import tempfile
+
from absl.testing import absltest
import tink
from tink.integration import awskms
+from tink.integration.awskms import _aws_kms_client
from tink.testing import helper
+
CREDENTIAL_PATH = os.path.join(helper.tink_py_testdata_path(),
'aws/credentials.ini')
KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
@@ -63,6 +67,59 @@ class AwsKmsClientTest(absltest.TestCase):
with self.assertRaises(ValueError):
awskms.AwsKmsClient(KEY_URI, '../credentials.txt')
+ def test_parse_valid_credentials_works(self):
+ config_file = tempfile.NamedTemporaryFile(delete=False)
+ with open(config_file.name, 'w') as f:
+ f.write("""
+[otherSection]
+aws_access_key_id = other_key_id
+aws_secret_access_key = other_key
+
+[default]
+aws_access_key_id = key_id_123
+aws_secret_access_key = key_123""")
+
+ aws_access_key_id, aws_secret_access_key = _aws_kms_client._parse_config(
+ config_file.name
+ )
+ self.assertEqual(aws_access_key_id, 'key_id_123')
+ self.assertEqual(aws_secret_access_key, 'key_123')
+
+ os.unlink(config_file.name)
+
+ def test_parse_credentials_without_key_id_fails(self):
+ config_file = tempfile.NamedTemporaryFile(delete=False)
+ with open(config_file.name, 'w') as f:
+ f.write("""
+[default]
+aws_secret_access_key = key_123""")
+ with self.assertRaises(ValueError):
+ _aws_kms_client._parse_config(config_file.name)
+
+ os.unlink(config_file.name)
+
+ def test_parse_credentials_without_key_fails(self):
+ config_file = tempfile.NamedTemporaryFile(delete=False)
+ with open(config_file.name, 'w') as f:
+ f.write("""
+[default]
+aws_secret_access_key = key_123""")
+ with self.assertRaises(ValueError):
+ _aws_kms_client._parse_config(config_file.name)
+
+ os.unlink(config_file.name)
+
+ def test_parse_credentials_without_default_section_fails(self):
+ config_file = tempfile.NamedTemporaryFile(delete=False)
+ with open(config_file.name, 'w') as f:
+ f.write("""
+[otherSection]
+aws_access_key_id = other_key_id
+aws_secret_access_key = other_key""")
+ with self.assertRaises(ValueError):
+ _aws_kms_client._parse_config(config_file.name)
+
+ os.unlink(config_file.name)
if __name__ == '__main__':
absltest.main()