aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oauth2client/client.py19
-rw-r--r--oauth2client/contrib/_metadata.py7
-rw-r--r--oauth2client/transport.py59
-rw-r--r--tests/contrib/test_metadata.py14
-rw-r--r--tests/test_client.py17
-rw-r--r--tests/test_transport.py40
6 files changed, 124 insertions, 32 deletions
diff --git a/oauth2client/client.py b/oauth2client/client.py
index cc2dd09..5bdf911 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -792,8 +792,9 @@ class OAuth2Credentials(Credentials):
headers = self._generate_refresh_request_headers()
logger.info('Refreshing access_token')
- resp, content = http_request(
- self.token_uri, method='POST', body=body, headers=headers)
+ resp, content = transport.request(
+ http_request, self.token_uri, method='POST',
+ body=body, headers=headers)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
d = json.loads(content)
@@ -859,7 +860,7 @@ class OAuth2Credentials(Credentials):
logger.info('Revoking token')
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
- resp, content = http_request(token_revoke_uri)
+ resp, content = transport.request(http_request, token_revoke_uri)
if resp.status == http_client.OK:
self.invalid = True
else:
@@ -903,7 +904,7 @@ class OAuth2Credentials(Credentials):
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri,
query_params)
- resp, content = http_request(token_info_uri)
+ resp, content = transport.request(http_request, token_info_uri)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
d = json.loads(content)
@@ -1571,7 +1572,7 @@ def verify_id_token(id_token, audience, http=None,
if http is None:
http = transport.get_cached_http()
- resp, content = http.request(cert_uri)
+ resp, content = transport.request(http, cert_uri)
if resp.status == http_client.OK:
certs = json.loads(_helpers._from_bytes(content))
return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
@@ -1939,8 +1940,8 @@ class OAuth2WebServerFlow(Flow):
if http is None:
http = transport.get_http_object()
- resp, content = http.request(self.device_uri, method='POST', body=body,
- headers=headers)
+ resp, content = transport.request(
+ http, self.device_uri, method='POST', body=body, headers=headers)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
try:
@@ -2022,8 +2023,8 @@ class OAuth2WebServerFlow(Flow):
if http is None:
http = transport.get_http_object()
- resp, content = http.request(self.token_uri, method='POST', body=body,
- headers=headers)
+ resp, content = transport.request(
+ http, self.token_uri, method='POST', body=body, headers=headers)
d = _parse_exchange_token_response(content)
if resp.status == http_client.OK and 'access_token' in d:
access_token = d['access_token']
diff --git a/oauth2client/contrib/_metadata.py b/oauth2client/contrib/_metadata.py
index 52af1ae..4c4588a 100644
--- a/oauth2client/contrib/_metadata.py
+++ b/oauth2client/contrib/_metadata.py
@@ -25,6 +25,7 @@ from six.moves.urllib import parse as urlparse
from oauth2client import _helpers
from oauth2client import client
+from oauth2client import transport
METADATA_ROOT = 'http://metadata.google.internal/computeMetadata/v1/'
@@ -55,10 +56,8 @@ def get(http_request, path, root=METADATA_ROOT, recursive=None):
url = urlparse.urljoin(root, path)
url = _helpers._add_query_parameter(url, 'recursive', recursive)
- response, content = http_request(
- url,
- headers=METADATA_HEADERS
- )
+ response, content = transport.request(
+ http_request, url, headers=METADATA_HEADERS)
if response.status == http_client.OK:
decoded = _helpers._from_bytes(content)
diff --git a/oauth2client/transport.py b/oauth2client/transport.py
index b42c7fd..ed256da 100644
--- a/oauth2client/transport.py
+++ b/oauth2client/transport.py
@@ -170,9 +170,9 @@ def wrap_http_for_auth(credentials, http):
_STREAM_PROPERTIES):
body_stream_position = body.tell()
- resp, content = orig_request_method(uri, method, body,
- clean_headers(headers),
- redirections, connection_type)
+ resp, content = request(orig_request_method, uri, method, body,
+ clean_headers(headers),
+ redirections, connection_type)
# A stored token may expire between the time it is retrieved and
# the time the request is made, so we may need to try twice.
@@ -188,9 +188,9 @@ def wrap_http_for_auth(credentials, http):
if body_stream_position is not None:
body.seek(body_stream_position)
- resp, content = orig_request_method(uri, method, body,
- clean_headers(headers),
- redirections, connection_type)
+ resp, content = request(orig_request_method, uri, method, body,
+ clean_headers(headers),
+ redirections, connection_type)
return resp, content
@@ -198,7 +198,7 @@ def wrap_http_for_auth(credentials, http):
http.request = new_request
# Set credentials as a property of the request method.
- setattr(http.request, 'credentials', credentials)
+ http.request.credentials = credentials
def wrap_http_for_jwt_access(credentials, http):
@@ -228,9 +228,9 @@ def wrap_http_for_jwt_access(credentials, http):
if (credentials.access_token is None or
credentials.access_token_expired):
credentials.refresh(None)
- return authenticated_request_method(uri, method, body,
- headers, redirections,
- connection_type)
+ return request(authenticated_request_method, uri,
+ method, body, headers, redirections,
+ connection_type)
else:
# If we don't have an 'aud' (audience) claim,
# create a 1-time token with the uri root as the audience
@@ -240,12 +240,45 @@ def wrap_http_for_jwt_access(credentials, http):
token, unused_expiry = credentials._create_token({'aud': uri_root})
headers['Authorization'] = 'Bearer ' + token
- return orig_request_method(uri, method, body,
- clean_headers(headers),
- redirections, connection_type)
+ return request(orig_request_method, uri, method, body,
+ clean_headers(headers),
+ redirections, connection_type)
# Replace the request method with our own closure.
http.request = new_request
+ # Set credentials as a property of the request method.
+ http.request.credentials = credentials
+
+
+def request(http, uri, method='GET', body=None, headers=None,
+ redirections=httplib2.DEFAULT_MAX_REDIRECTS,
+ connection_type=None):
+ """Make an HTTP request with an HTTP object and arguments.
+
+ Args:
+ http: httplib2.Http, an http object to be used to make requests.
+ uri: string, The URI to be requested.
+ method: string, The HTTP method to use for the request. Defaults
+ to 'GET'.
+ body: string, The payload / body in HTTP request. By default
+ there is no payload.
+ headers: dict, Key-value pairs of request headers. By default
+ there are no headers.
+ redirections: int, The number of allowed 203 redirects for
+ the request. Defaults to 5.
+ connection_type: httplib.HTTPConnection, a subclass to be used for
+ establishing connection. If not set, the type
+ will be determined from the ``uri``.
+
+ Returns:
+ tuple, a pair of a httplib2.Response with the status code and other
+ headers and the bytes of the content returned.
+ """
+ http_callable = getattr(http, 'request', http)
+ return http_callable(uri, method=method, body=body, headers=headers,
+ redirections=redirections,
+ connection_type=connection_type)
+
_CACHED_HTTP = httplib2.Http(MemoryCache())
diff --git a/tests/contrib/test_metadata.py b/tests/contrib/test_metadata.py
index 5af4ab9..6907f39 100644
--- a/tests/contrib/test_metadata.py
+++ b/tests/contrib/test_metadata.py
@@ -28,13 +28,23 @@ DATA = {'foo': 'bar'}
EXPECTED_URL = (
'http://metadata.google.internal/computeMetadata/v1/instance'
'/service-accounts/default')
-EXPECTED_KWARGS = dict(headers=_metadata.METADATA_HEADERS)
+EXPECTED_KWARGS = {
+ 'headers': _metadata.METADATA_HEADERS,
+ 'body': None,
+ 'connection_type': None,
+ 'method': 'GET',
+ 'redirections': 5,
+}
def request_mock(status, content_type, content):
response = http_mock.ResponseMock(
{'status': status, 'content-type': content_type})
- return mock.Mock(return_value=(response, content.encode('utf-8')))
+ request_method = mock.Mock(
+ return_value=(response, content.encode('utf-8')))
+ # Make sure the mock doesn't have a request attr.
+ del request_method.request
+ return request_method
class TestMetadata(unittest2.TestCase):
diff --git a/tests/test_client.py b/tests/test_client.py
index 79727d8..c3597da 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1220,6 +1220,8 @@ class BasicCredentialsTests(unittest2.TestCase):
None, None, None)
credentials.store = store
http_request = mock.Mock()
+ # Make sure the mock doesn't have a request attr.
+ del http_request.request
http_request.return_value = response, content
with self.assertRaises(
@@ -1228,9 +1230,10 @@ class BasicCredentialsTests(unittest2.TestCase):
self.assertEqual(exc_manager.exception.args, (error_msg,))
self.assertEqual(exc_manager.exception.status, response.status)
- http_request.assert_called_once_with(None, body=gen_body.return_value,
- headers=gen_headers.return_value,
- method='POST')
+ http_request.assert_called_once_with(
+ None, method='POST', body=gen_body.return_value,
+ headers=gen_headers.return_value, redirections=5,
+ connection_type=None)
call1 = mock.call('Refreshing access_token')
failure_template = 'Failed to retrieve access token: %s'
@@ -1286,6 +1289,8 @@ class BasicCredentialsTests(unittest2.TestCase):
revoke_uri=oauth2client.GOOGLE_REVOKE_URI)
credentials.store = store
http_request = mock.Mock()
+ # Make sure the mock doesn't have a request attr.
+ del http_request.request
http_request.return_value = response, content
token = u's3kr3tz'
@@ -1306,7 +1311,9 @@ class BasicCredentialsTests(unittest2.TestCase):
store.delete.assert_not_called()
revoke_uri = oauth2client.GOOGLE_REVOKE_URI + '?token=' + token
- http_request.assert_called_once_with(revoke_uri)
+ http_request.assert_called_once_with(
+ revoke_uri, method='GET', body=None, headers=None,
+ redirections=5, connection_type=None)
logger.info.assert_called_once_with('Revoking token')
@@ -1353,6 +1360,8 @@ class BasicCredentialsTests(unittest2.TestCase):
None, None, None, None, None, None, None,
token_info_uri=oauth2client.GOOGLE_TOKEN_INFO_URI)
http_request = mock.Mock()
+ # Make sure the mock doesn't have a request attr.
+ del http_request.request
http_request.return_value = response, content
token = u's3kr3tz'
diff --git a/tests/test_transport.py b/tests/test_transport.py
index f783ed2..fdf1f73 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -136,3 +136,43 @@ class Test_wrap_http_for_auth(unittest2.TestCase):
self.assertIsNone(result)
self.assertNotEqual(http.request, orig_req_method)
self.assertIs(http.request.credentials, credentials)
+
+
+class Test_request(unittest2.TestCase):
+
+ uri = 'http://localhost'
+ method = 'POST'
+ body = 'abc'
+ redirections = 3
+
+ def test_with_request_attr(self):
+ http = mock.Mock()
+ mock_result = object()
+ mock_request = mock.Mock(return_value=mock_result)
+ http.request = mock_request
+
+ result = transport.request(http, self.uri, method=self.method,
+ body=self.body,
+ redirections=self.redirections)
+ self.assertIs(result, mock_result)
+ # Verify mock.
+ mock_request.assert_called_once_with(self.uri, method=self.method,
+ body=self.body,
+ redirections=self.redirections,
+ headers=None,
+ connection_type=None)
+
+ def test_with_callable_http(self):
+ mock_result = object()
+ http = mock.Mock(return_value=mock_result)
+ del http.request # Make sure the mock doesn't have a request attr.
+
+ result = transport.request(http, self.uri, method=self.method,
+ body=self.body,
+ redirections=self.redirections)
+ self.assertIs(result, mock_result)
+ # Verify mock.
+ http.assert_called_once_with(self.uri, method=self.method,
+ body=self.body,
+ redirections=self.redirections,
+ headers=None, connection_type=None)