diff options
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | oauth2client/client.py | 18 | ||||
-rw-r--r-- | oauth2client/service_account.py | 3 | ||||
-rw-r--r-- | tests/test_jwt.py | 6 | ||||
-rw-r--r-- | tests/test_oauth2client.py | 53 |
5 files changed, 46 insertions, 36 deletions
diff --git a/.travis.yml b/.travis.yml index 569b53b..a30cb68 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,8 @@ env: - TOX_ENV=py26openssl14 - TOX_ENV=py27openssl13 - TOX_ENV=py27openssl14 + - TOX_ENV=py33openssl14 + - TOX_ENV=py34openssl14 - TOX_ENV=pypyopenssl13 - TOX_ENV=pypyopenssl14 install: diff --git a/oauth2client/client.py b/oauth2client/client.py index 18eee46..c5b7214 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -409,7 +409,7 @@ def clean_headers(headers): clean = {} try: for k, v in six.iteritems(headers): - clean[str(k)] = str(v) + clean[k.encode('ascii')] = v.encode('ascii') except UnicodeEncodeError: raise NonAsciiHeaderError(k + ': ' + v) return clean @@ -1252,16 +1252,14 @@ def _get_well_known_file(): return default_config_path -def _get_application_default_credential_from_file( - application_default_credential_filename): +def _get_application_default_credential_from_file(filename): """Build the Application Default Credentials from file.""" from oauth2client import service_account # read the credentials from the file - with open(application_default_credential_filename) as ( - application_default_credential): - client_credentials = json.load(application_default_credential) + with open(filename) as file_obj: + client_credentials = json.load(file_obj) credentials_type = client_credentials.get('type') if credentials_type == AUTHORIZED_USER: @@ -1545,12 +1543,15 @@ def _extract_id_token(id_token): Does the extraction w/o checking the signature. Args: - id_token: string, OAuth 2.0 id_token. + id_token: string or bytestring, OAuth 2.0 id_token. Returns: object, The deserialized JSON payload. """ - segments = id_token.split('.') + if type(id_token) == bytes: + segments = id_token.split(b'.') + else: + segments = id_token.split(u'.') if len(segments) != 3: raise VerifyJwtTokenError( @@ -1578,6 +1579,7 @@ def _parse_exchange_token_response(content): except Exception: # different JSON libs raise different exceptions, # so we just do a catch-all here + content = content.decode('utf-8') resp = dict(urllib.parse.parse_qsl(content)) # some providers respond with 'expires', others with 'expires_in' diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index 1415d08..0353258 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -19,6 +19,7 @@ This credentials class is implemented on top of rsa library. import base64 import json +import six import time from pyasn1.codec.ber import decoder @@ -131,6 +132,8 @@ def _urlsafe_b64encode(data): def _get_private_key(private_key_pkcs8_text): """Get an RSA private key object from a pkcs8 representation.""" + if not isinstance(private_key_pkcs8_text, six.binary_type): + private_key_pkcs8_text = private_key_pkcs8_text.encode('ascii') der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) return rsa.PrivateKey.load_pkcs1( diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 169e5f8..a2cd37d 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -217,7 +217,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): ]) http = credentials.authorize(http) resp, content = http.request('http://example.org') - self.assertEqual('Bearer 1/3w', content['Authorization']) + self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) def test_credentials_to_from_json(self): private_key = datafile('privatekey.%s' % self.format) @@ -254,7 +254,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): content = self._credentials_refresh(credentials) - self.assertEqual('Bearer 3/3w', content['Authorization']) + self.assertEqual(b'Bearer 3/3w', content[b'Authorization']) def test_credentials_refresh_with_storage(self): private_key = datafile('privatekey.%s' % self.format) @@ -272,7 +272,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): content = self._credentials_refresh(credentials) - self.assertEqual('Bearer 3/3w', content['Authorization']) + self.assertEqual(b'Bearer 3/3w', content[b'Authorization']) os.unlink(filename) diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index 629f929..979eb81 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -545,7 +545,7 @@ class BasicCredentialsTests(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request('http://example.com') - self.assertEqual('Bearer 1/3w', content['Authorization']) + self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response, self.credentials.token_response) @@ -615,10 +615,10 @@ class BasicCredentialsTests(unittest.TestCase): http = credentials.authorize(http) http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'}) for k, v in six.iteritems(http.headers): - self.assertEqual(str, type(k)) - self.assertEqual(str, type(v)) + self.assertEqual(six.binary_type, type(k)) + self.assertEqual(six.binary_type, type(v)) - # Test again with unicode strings that can't simple be converted to ASCII. + # Test again with unicode strings that can't simply be converted to ASCII. try: http.request( u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'}) @@ -707,7 +707,7 @@ class AccessTokenCredentialsTests(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request('http://example.com') - self.assertEqual('Bearer foo', content['Authorization']) + self.assertEqual(b'Bearer foo', content[b'Authorization']) class TestAssertionCredentials(unittest.TestCase): @@ -738,7 +738,7 @@ class TestAssertionCredentials(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request('http://example.com') - self.assertEqual('Bearer 1/3w', content['Authorization']) + self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) def test_token_revoke_success(self): _token_revoke_test_helper( @@ -769,16 +769,18 @@ class ExtractIdTokenTest(unittest.TestCase): def test_extract_success(self): body = {'foo': 'bar'} - payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') - jwt = 'stuff.' + payload + '.signature' + body_json = json.dumps(body).encode('ascii') + payload = base64.urlsafe_b64encode(body_json).strip(b'=') + jwt = b'stuff.' + payload + b'.signature' extracted = _extract_id_token(jwt) self.assertEqual(extracted, body) def test_extract_failure(self): body = {'foo': 'bar'} - payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') - jwt = 'stuff.' + payload + body_json = json.dumps(body).encode('ascii') + payload = base64.urlsafe_b64encode(body_json).strip(b'=') + jwt = b'stuff.' + payload self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt) @@ -840,14 +842,14 @@ class OAuth2WebServerFlowTest(unittest.TestCase): def test_urlencoded_exchange_failure(self): http = HttpMockSequence([ - ({'status': '400'}, 'error=invalid_request'), + ({'status': '400'}, b'error=invalid_request'), ]) try: credentials = self.flow.step2_exchange('some random code', http=http) self.fail('should raise exception if exchange doesn\'t get 200') except FlowExchangeError as e: - self.assertEquals('invalid_request', str(e)) + self.assertEqual('invalid_request', str(e)) def test_exchange_failure_with_json_error(self): # Some providers have 'error' attribute as a JSON object @@ -894,12 +896,12 @@ class OAuth2WebServerFlowTest(unittest.TestCase): code = 'some random code' not_a_dict = FakeDict({'code': code}) - http = HttpMockSequence([ - ({'status': '200'}, - """{ "access_token":"SlAV32hkKG", - "expires_in":3600, - "refresh_token":"8xLOxBtZp8" }"""), - ]) + payload = (b'{' + b' "access_token":"SlAV32hkKG",' + b' "expires_in":3600,' + b' "refresh_token":"8xLOxBtZp8"' + b'}') + http = HttpMockSequence([({'status': '200'}, payload),]) credentials = self.flow.step2_exchange(not_a_dict, http=http) self.assertEqual('SlAV32hkKG', credentials.access_token) @@ -972,9 +974,10 @@ class OAuth2WebServerFlowTest(unittest.TestCase): def test_exchange_id_token(self): body = {'foo': 'bar'} - payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') - jwt = (base64.urlsafe_b64encode('stuff')+ '.' + payload + '.' + - base64.urlsafe_b64encode('signature')) + body_json = json.dumps(body).encode('ascii') + payload = base64.urlsafe_b64encode(body_json).strip(b'=') + jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' + + base64.urlsafe_b64encode(b'signature')) http = HttpMockSequence([ ({'status': '200'}, ("""{ "access_token":"SlAV32hkKG", @@ -994,7 +997,7 @@ class FlowFromCachedClientsecrets(unittest.TestCase): flow = flow_from_clientsecrets( 'some_secrets', '', redirect_uri='oob', cache=cache_mock) - self.assertEquals('foo_client_secret', flow.client_secret) + self.assertEqual('foo_client_secret', flow.client_secret) class CredentialsFromCodeTests(unittest.TestCase): @@ -1014,7 +1017,7 @@ class CredentialsFromCodeTests(unittest.TestCase): credentials = credentials_from_code(self.client_id, self.client_secret, self.scope, self.code, redirect_uri=self.redirect_uri, http=http) - self.assertEquals(credentials.access_token, token) + self.assertEqual(credentials.access_token, token) self.assertNotEqual(None, credentials.token_expiry) def test_exchange_code_for_token_fail(self): @@ -1039,7 +1042,7 @@ class CredentialsFromCodeTests(unittest.TestCase): credentials = credentials_from_clientsecrets_and_code( datafile('client_secrets.json'), self.scope, self.code, http=http) - self.assertEquals(credentials.access_token, 'asdfghjkl') + self.assertEqual(credentials.access_token, 'asdfghjkl') self.assertNotEqual(None, credentials.token_expiry) def test_exchange_code_and_cached_file_for_token(self): @@ -1052,7 +1055,7 @@ class CredentialsFromCodeTests(unittest.TestCase): credentials = credentials_from_clientsecrets_and_code( 'some_secrets', self.scope, self.code, http=http, cache=cache_mock) - self.assertEquals(credentials.access_token, 'asdfghjkl') + self.assertEqual(credentials.access_token, 'asdfghjkl') def test_exchange_code_and_file_for_token_fail(self): http = HttpMockSequence([ |