aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml2
-rw-r--r--oauth2client/client.py18
-rw-r--r--oauth2client/service_account.py3
-rw-r--r--tests/test_jwt.py6
-rw-r--r--tests/test_oauth2client.py53
5 files changed, 46 insertions, 36 deletions
diff --git a/.travis.yml b/.travis.yml
index 569b53b..a30cb68 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,8 @@ env:
- TOX_ENV=py26openssl14
- TOX_ENV=py27openssl13
- TOX_ENV=py27openssl14
+ - TOX_ENV=py33openssl14
+ - TOX_ENV=py34openssl14
- TOX_ENV=pypyopenssl13
- TOX_ENV=pypyopenssl14
install:
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 18eee46..c5b7214 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -409,7 +409,7 @@ def clean_headers(headers):
clean = {}
try:
for k, v in six.iteritems(headers):
- clean[str(k)] = str(v)
+ clean[k.encode('ascii')] = v.encode('ascii')
except UnicodeEncodeError:
raise NonAsciiHeaderError(k + ': ' + v)
return clean
@@ -1252,16 +1252,14 @@ def _get_well_known_file():
return default_config_path
-def _get_application_default_credential_from_file(
- application_default_credential_filename):
+def _get_application_default_credential_from_file(filename):
"""Build the Application Default Credentials from file."""
from oauth2client import service_account
# read the credentials from the file
- with open(application_default_credential_filename) as (
- application_default_credential):
- client_credentials = json.load(application_default_credential)
+ with open(filename) as file_obj:
+ client_credentials = json.load(file_obj)
credentials_type = client_credentials.get('type')
if credentials_type == AUTHORIZED_USER:
@@ -1545,12 +1543,15 @@ def _extract_id_token(id_token):
Does the extraction w/o checking the signature.
Args:
- id_token: string, OAuth 2.0 id_token.
+ id_token: string or bytestring, OAuth 2.0 id_token.
Returns:
object, The deserialized JSON payload.
"""
- segments = id_token.split('.')
+ if type(id_token) == bytes:
+ segments = id_token.split(b'.')
+ else:
+ segments = id_token.split(u'.')
if len(segments) != 3:
raise VerifyJwtTokenError(
@@ -1578,6 +1579,7 @@ def _parse_exchange_token_response(content):
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
+ content = content.decode('utf-8')
resp = dict(urllib.parse.parse_qsl(content))
# some providers respond with 'expires', others with 'expires_in'
diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py
index 1415d08..0353258 100644
--- a/oauth2client/service_account.py
+++ b/oauth2client/service_account.py
@@ -19,6 +19,7 @@ This credentials class is implemented on top of rsa library.
import base64
import json
+import six
import time
from pyasn1.codec.ber import decoder
@@ -131,6 +132,8 @@ def _urlsafe_b64encode(data):
def _get_private_key(private_key_pkcs8_text):
"""Get an RSA private key object from a pkcs8 representation."""
+ if not isinstance(private_key_pkcs8_text, six.binary_type):
+ private_key_pkcs8_text = private_key_pkcs8_text.encode('ascii')
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
return rsa.PrivateKey.load_pkcs1(
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 169e5f8..a2cd37d 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -217,7 +217,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
])
http = credentials.authorize(http)
resp, content = http.request('http://example.org')
- self.assertEqual('Bearer 1/3w', content['Authorization'])
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_credentials_to_from_json(self):
private_key = datafile('privatekey.%s' % self.format)
@@ -254,7 +254,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
content = self._credentials_refresh(credentials)
- self.assertEqual('Bearer 3/3w', content['Authorization'])
+ self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
def test_credentials_refresh_with_storage(self):
private_key = datafile('privatekey.%s' % self.format)
@@ -272,7 +272,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
content = self._credentials_refresh(credentials)
- self.assertEqual('Bearer 3/3w', content['Authorization'])
+ self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
os.unlink(filename)
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
index 629f929..979eb81 100644
--- a/tests/test_oauth2client.py
+++ b/tests/test_oauth2client.py
@@ -545,7 +545,7 @@ class BasicCredentialsTests(unittest.TestCase):
])
http = self.credentials.authorize(http)
resp, content = http.request('http://example.com')
- self.assertEqual('Bearer 1/3w', content['Authorization'])
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response, self.credentials.token_response)
@@ -615,10 +615,10 @@ class BasicCredentialsTests(unittest.TestCase):
http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers):
- self.assertEqual(str, type(k))
- self.assertEqual(str, type(v))
+ self.assertEqual(six.binary_type, type(k))
+ self.assertEqual(six.binary_type, type(v))
- # Test again with unicode strings that can't simple be converted to ASCII.
+ # Test again with unicode strings that can't simply be converted to ASCII.
try:
http.request(
u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'})
@@ -707,7 +707,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
])
http = self.credentials.authorize(http)
resp, content = http.request('http://example.com')
- self.assertEqual('Bearer foo', content['Authorization'])
+ self.assertEqual(b'Bearer foo', content[b'Authorization'])
class TestAssertionCredentials(unittest.TestCase):
@@ -738,7 +738,7 @@ class TestAssertionCredentials(unittest.TestCase):
])
http = self.credentials.authorize(http)
resp, content = http.request('http://example.com')
- self.assertEqual('Bearer 1/3w', content['Authorization'])
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_token_revoke_success(self):
_token_revoke_test_helper(
@@ -769,16 +769,18 @@ class ExtractIdTokenTest(unittest.TestCase):
def test_extract_success(self):
body = {'foo': 'bar'}
- payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=')
- jwt = 'stuff.' + payload + '.signature'
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = b'stuff.' + payload + b'.signature'
extracted = _extract_id_token(jwt)
self.assertEqual(extracted, body)
def test_extract_failure(self):
body = {'foo': 'bar'}
- payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=')
- jwt = 'stuff.' + payload
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = b'stuff.' + payload
self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)
@@ -840,14 +842,14 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
def test_urlencoded_exchange_failure(self):
http = HttpMockSequence([
- ({'status': '400'}, 'error=invalid_request'),
+ ({'status': '400'}, b'error=invalid_request'),
])
try:
credentials = self.flow.step2_exchange('some random code', http=http)
self.fail('should raise exception if exchange doesn\'t get 200')
except FlowExchangeError as e:
- self.assertEquals('invalid_request', str(e))
+ self.assertEqual('invalid_request', str(e))
def test_exchange_failure_with_json_error(self):
# Some providers have 'error' attribute as a JSON object
@@ -894,12 +896,12 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
code = 'some random code'
not_a_dict = FakeDict({'code': code})
- http = HttpMockSequence([
- ({'status': '200'},
- """{ "access_token":"SlAV32hkKG",
- "expires_in":3600,
- "refresh_token":"8xLOxBtZp8" }"""),
- ])
+ payload = (b'{'
+ b' "access_token":"SlAV32hkKG",'
+ b' "expires_in":3600,'
+ b' "refresh_token":"8xLOxBtZp8"'
+ b'}')
+ http = HttpMockSequence([({'status': '200'}, payload),])
credentials = self.flow.step2_exchange(not_a_dict, http=http)
self.assertEqual('SlAV32hkKG', credentials.access_token)
@@ -972,9 +974,10 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
def test_exchange_id_token(self):
body = {'foo': 'bar'}
- payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=')
- jwt = (base64.urlsafe_b64encode('stuff')+ '.' + payload + '.' +
- base64.urlsafe_b64encode('signature'))
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' +
+ base64.urlsafe_b64encode(b'signature'))
http = HttpMockSequence([
({'status': '200'}, ("""{ "access_token":"SlAV32hkKG",
@@ -994,7 +997,7 @@ class FlowFromCachedClientsecrets(unittest.TestCase):
flow = flow_from_clientsecrets(
'some_secrets', '', redirect_uri='oob', cache=cache_mock)
- self.assertEquals('foo_client_secret', flow.client_secret)
+ self.assertEqual('foo_client_secret', flow.client_secret)
class CredentialsFromCodeTests(unittest.TestCase):
@@ -1014,7 +1017,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, redirect_uri=self.redirect_uri,
http=http)
- self.assertEquals(credentials.access_token, token)
+ self.assertEqual(credentials.access_token, token)
self.assertNotEqual(None, credentials.token_expiry)
def test_exchange_code_for_token_fail(self):
@@ -1039,7 +1042,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_clientsecrets_and_code(
datafile('client_secrets.json'), self.scope,
self.code, http=http)
- self.assertEquals(credentials.access_token, 'asdfghjkl')
+ self.assertEqual(credentials.access_token, 'asdfghjkl')
self.assertNotEqual(None, credentials.token_expiry)
def test_exchange_code_and_cached_file_for_token(self):
@@ -1052,7 +1055,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_clientsecrets_and_code(
'some_secrets', self.scope,
self.code, http=http, cache=cache_mock)
- self.assertEquals(credentials.access_token, 'asdfghjkl')
+ self.assertEqual(credentials.access_token, 'asdfghjkl')
def test_exchange_code_and_file_for_token_fail(self):
http = HttpMockSequence([