aboutsummaryrefslogtreecommitdiff
path: root/tests/contrib/test_flask_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/contrib/test_flask_util.py')
-rw-r--r--tests/contrib/test_flask_util.py94
1 files changed, 54 insertions, 40 deletions
diff --git a/tests/contrib/test_flask_util.py b/tests/contrib/test_flask_util.py
index 74cb218..112bff0 100644
--- a/tests/contrib/test_flask_util.py
+++ b/tests/contrib/test_flask_util.py
@@ -17,54 +17,31 @@
import datetime
import json
import logging
+import unittest
import flask
-import httplib2
import mock
import six.moves.http_client as httplib
import six.moves.urllib.parse as urlparse
-import unittest2
import oauth2client
from oauth2client import client
from oauth2client import clientsecrets
from oauth2client.contrib import flask_util
+from tests import http_mock
-__author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
+DEFAULT_RESP = """\
+{
+ "access_token": "foo_access_token",
+ "expires_in": 3600,
+ "extra": "value",
+ "refresh_token": "foo_refresh_token"
+}
+"""
-class Http2Mock(object):
- """Mock httplib2.Http for code exchange / refresh"""
-
- def __init__(self, status=httplib.OK, **kwargs):
- self.status = status
- self.content = {
- 'access_token': 'foo_access_token',
- 'refresh_token': 'foo_refresh_token',
- 'expires_in': 3600,
- 'extra': 'value',
- }
- self.content.update(kwargs)
-
- def request(self, token_uri, method, body, headers, *args, **kwargs):
- self.body = body
- self.headers = headers
- return (self, json.dumps(self.content).encode('utf-8'))
-
- def __enter__(self):
- self.httplib2_orig = httplib2.Http
- httplib2.Http = self
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- httplib2.Http = self.httplib2_orig
-
- def __call__(self, *args, **kwargs):
- return self
-
-
-class FlaskOAuth2Tests(unittest2.TestCase):
+class FlaskOAuth2Tests(unittest.TestCase):
def setUp(self):
self.app = flask.Flask(__name__)
@@ -246,7 +223,12 @@ class FlaskOAuth2Tests(unittest2.TestCase):
def test_callback_view(self):
self.oauth2.storage = mock.Mock()
with self.app.test_client() as client:
- with Http2Mock() as http:
+ with mock.patch(
+ 'oauth2client.transport.get_http_object') as new_http:
+ # Set-up mock.
+ http = http_mock.HttpMock(data=DEFAULT_RESP)
+ new_http.return_value = http
+ # Run tests.
state = self._setup_callback_state(client)
response = client.get(
@@ -258,6 +240,9 @@ class FlaskOAuth2Tests(unittest2.TestCase):
self.assertIn('codez', http.body)
self.assertTrue(self.oauth2.storage.put.called)
+ # Check the mocks were called.
+ new_http.assert_called_once_with()
+
def test_authorize_callback(self):
self.oauth2.authorize_callback = mock.Mock()
self.test_callback_view()
@@ -273,6 +258,18 @@ class FlaskOAuth2Tests(unittest2.TestCase):
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
self.assertIn('something', response.data.decode('utf-8'))
+ # Error supplied to callback with html
+ with self.app.test_client() as client:
+ with client.session_transaction() as session:
+ session['google_oauth2_csrf_token'] = 'tokenz'
+
+ response = client.get(
+ '/oauth2callback?state={}&error=<script>something<script>')
+ self.assertEqual(response.status_code, httplib.BAD_REQUEST)
+ self.assertIn(
+ '&lt;script&gt;something&lt;script&gt;',
+ response.data.decode('utf-8'))
+
# CSRF mismatch
with self.app.test_client() as client:
with client.session_transaction() as session:
@@ -296,11 +293,20 @@ class FlaskOAuth2Tests(unittest2.TestCase):
with self.app.test_client() as client:
state = self._setup_callback_state(client)
- with Http2Mock(status=httplib.INTERNAL_SERVER_ERROR):
+ with mock.patch(
+ 'oauth2client.transport.get_http_object') as new_http:
+ # Set-up mock.
+ new_http.return_value = http_mock.HttpMock(
+ headers={'status': httplib.INTERNAL_SERVER_ERROR},
+ data=DEFAULT_RESP)
+ # Run tests.
response = client.get(
'/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
+ # Check the mocks were called.
+ new_http.assert_called_once_with()
+
# Invalid state json
with self.app.test_client() as client:
with client.session_transaction() as session:
@@ -495,7 +501,10 @@ class FlaskOAuth2Tests(unittest2.TestCase):
def test_incremental_auth_exchange(self):
self._create_incremental_auth_app()
- with Http2Mock():
+ with mock.patch('oauth2client.transport.get_http_object') as new_http:
+ # Set-up mock.
+ new_http.return_value = http_mock.HttpMock(data=DEFAULT_RESP)
+ # Run tests.
with self.app.test_client() as client:
state = self._setup_callback_state(
client,
@@ -511,16 +520,21 @@ class FlaskOAuth2Tests(unittest2.TestCase):
self.assertTrue(
credentials.has_scopes(['email', 'one', 'two']))
+ # Check the mocks were called.
+ new_http.assert_called_once_with()
+
def test_refresh(self):
+ token_val = 'new_token'
+ json_resp = '{"access_token": "%s"}' % (token_val,)
+ http = http_mock.HttpMock(data=json_resp)
with self.app.test_request_context():
with mock.patch('flask.session'):
self.oauth2.storage.put(self._generate_credentials())
- self.oauth2.credentials.refresh(
- Http2Mock(access_token='new_token'))
+ self.oauth2.credentials.refresh(http)
self.assertEqual(
- self.oauth2.storage.get().access_token, 'new_token')
+ self.oauth2.storage.get().access_token, token_val)
def test_delete(self):
with self.app.test_request_context():