aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPat Ferate <pferate@users.noreply.github.com>2016-08-12 14:15:31 -0700
committerJon Wayne Parrott <jonwayne@google.com>2016-08-12 14:15:31 -0700
commit5137d7e8377266ef4beffe1c59c638c05c82cf10 (patch)
tree0d3744551ca792253d184e8f65a9721d29e434ec
parentc9b4b07525730338f2e560981b3fbe295d2146ab (diff)
downloadoauth2client-5137d7e8377266ef4beffe1c59c638c05c82cf10.tar.gz
Complete branches from partial test coverages (#629)
-rw-r--r--tests/contrib/test_sqlalchemy.py26
-rw-r--r--tests/test_client.py59
2 files changed, 66 insertions, 19 deletions
diff --git a/tests/contrib/test_sqlalchemy.py b/tests/contrib/test_sqlalchemy.py
index 67762f6..068aa92 100644
--- a/tests/contrib/test_sqlalchemy.py
+++ b/tests/contrib/test_sqlalchemy.py
@@ -15,6 +15,7 @@
import datetime
import unittest
+import mock
import sqlalchemy
import sqlalchemy.ext.declarative
import sqlalchemy.orm
@@ -66,7 +67,8 @@ class TestSQLAlchemyStorage(unittest.TestCase):
self.assertEqual(result.token_uri, self.credentials.token_uri)
self.assertEqual(result.user_agent, self.credentials.user_agent)
- def test_get(self):
+ @mock.patch('oauth2client.client.OAuth2Credentials.set_store')
+ def test_get(self, set_store):
session = self.session()
credentials_storage = oauth2client.contrib.sqlalchemy.Storage(
session=session,
@@ -75,7 +77,21 @@ class TestSQLAlchemyStorage(unittest.TestCase):
key_value=1,
property_name='credentials',
)
+ # No credentials stored
self.assertIsNone(credentials_storage.get())
+
+ # Invalid credentials stored
+ session.add(DummyModel(
+ key=1,
+ credentials=oauth2client.client.Credentials(),
+ ))
+ session.commit()
+ bad_credentials = credentials_storage.get()
+ self.assertIsInstance(bad_credentials, oauth2client.client.Credentials)
+ set_store.assert_not_called()
+
+ # Valid credentials stored
+ session.query(DummyModel).filter_by(key=1).delete()
session.add(DummyModel(
key=1,
credentials=self.credentials,
@@ -83,16 +99,20 @@ class TestSQLAlchemyStorage(unittest.TestCase):
session.commit()
self.compare_credentials(credentials_storage.get())
+ set_store.assert_called_with(credentials_storage)
def test_put(self):
session = self.session()
- oauth2client.contrib.sqlalchemy.Storage(
+ storage = oauth2client.contrib.sqlalchemy.Storage(
session=session,
model_class=DummyModel,
key_name='key',
key_value=1,
property_name='credentials',
- ).put(self.credentials)
+ )
+ # Store invalid credentials first to verify overwriting
+ storage.put(oauth2client.client.Credentials())
+ storage.put(self.credentials)
session.commit()
entity = session.query(DummyModel).filter_by(key=1).first()
diff --git a/tests/test_client.py b/tests/test_client.py
index 49a9210..27f24d8 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1619,6 +1619,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
user_agent='unittest-sample/1.0',
revoke_uri='dummy_revoke_uri',
)
+ self.bad_verifier = b'__NOT_THE_VERIFIER_YOURE_LOOKING_FOR__'
+ self.good_verifier = b'__TEST_VERIFIER__'
+ self.good_challenger = b'__TEST_CHALLENGE__'
def test_construct_authorize_url(self):
authorize_url = self.flow.step1_get_authorize_url(state='state+1')
@@ -1691,19 +1694,42 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
@mock.patch('oauth2client.client._pkce.code_challenge')
@mock.patch('oauth2client.client._pkce.code_verifier')
def test_step1_get_authorize_url_pkce(self, fake_verifier, fake_challenge):
- fake_verifier.return_value = b'__TEST_VERIFIER__'
- fake_challenge.return_value = b'__TEST_CHALLENGE__'
+ fake_verifier.return_value = self.good_verifier
+ fake_challenge.return_value = self.good_challenger
flow = client.OAuth2WebServerFlow(
- 'client_id+1',
- scope='foo',
- redirect_uri='http://example.com',
- pkce=True)
+ 'client_id+1',
+ scope='foo',
+ redirect_uri='http://example.com',
+ pkce=True)
+ auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url())
+ self.assertEqual(flow.code_verifier, self.good_verifier)
+ results = dict(urllib.parse.parse_qsl(auth_url.query))
+ self.assertEqual(
+ results['code_challenge'], self.good_challenger.decode())
+ self.assertEqual(results['code_challenge_method'], 'S256')
+ fake_verifier.assert_called()
+ fake_challenge.assert_called_with(self.good_verifier)
+
+ @mock.patch('oauth2client.client._pkce.code_challenge')
+ @mock.patch('oauth2client.client._pkce.code_verifier')
+ def test_step1_get_authorize_url_pkce_invalid_verifier(
+ self, fake_verifier, fake_challenge):
+ fake_verifier.return_value = self.good_verifier
+ fake_challenge.return_value = self.good_challenger
+ flow = client.OAuth2WebServerFlow(
+ 'client_id+1',
+ scope='foo',
+ redirect_uri='http://example.com',
+ pkce=True,
+ code_verifier=self.bad_verifier)
auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url())
- self.assertEqual(flow.code_verifier, b'__TEST_VERIFIER__')
+ self.assertEqual(flow.code_verifier, self.bad_verifier)
results = dict(urllib.parse.parse_qsl(auth_url.query))
- self.assertEqual(results['code_challenge'], '__TEST_CHALLENGE__')
+ self.assertEqual(
+ results['code_challenge'], self.good_challenger.decode())
self.assertEqual(results['code_challenge_method'], 'S256')
- fake_challenge.assert_called_with(b'__TEST_VERIFIER__')
+ fake_verifier.assert_not_called()
+ fake_challenge.assert_called_with(self.bad_verifier)
def test_step1_get_authorize_url_without_redirect(self):
flow = client.OAuth2WebServerFlow('client_id+1', scope='foo',
@@ -1955,17 +1981,18 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
({'status': http_client.OK}, b'access_token=SlAV32hkKG'),
])
flow = client.OAuth2WebServerFlow(
- 'client_id+1',
- scope='foo',
- redirect_uri='http://example.com',
- pkce=True,
- code_verifier=b'__TEST_VERIFIER__'
- )
+ 'client_id+1',
+ scope='foo',
+ redirect_uri='http://example.com',
+ pkce=True,
+ code_verifier=self.good_verifier)
flow.step2_exchange(code='some random code', http=http)
self.assertEqual(len(http.requests), 1)
test_request = http.requests[0]
- self.assertIn('code_verifier=__TEST_VERIFIER__', test_request['body'])
+ self.assertIn(
+ 'code_verifier={0}'.format(self.good_verifier.decode()),
+ test_request['body'])
def test_exchange_using_authorization_header(self):
auth_header = 'Basic Y2xpZW50X2lkKzE6c2Vjexc_managerV0KzE=',