aboutsummaryrefslogtreecommitdiff
path: root/tests_async/transport/async_compliance.py
blob: 385a9236a1f276fc909fa3964eabb7cef2f8549c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import http.client
import time

import flask
import pytest
from pytest_localserver.http import WSGIServer

from google.auth import exceptions
from tests.transport import compliance


class RequestResponseTests(object):
    """Compliance tests shared by async transport ``Request`` adapters.

    Concrete test classes inherit from this class and must supply two
    zero-argument hooks:

    * ``make_request()`` -- returns the awaitable request callable under test.
    * ``make_with_parameter_request()`` -- returns a request callable built
      from explicit constructor arguments (presumably a caller-supplied HTTP
      session; confirm against the concrete subclass).

    The request callable is awaited as ``request(url=..., method=..., ...)``
    (optionally with ``headers=`` and ``timeout=``) and must return a response
    exposing ``status``, ``headers`` and an async ``data.read(n)`` stream.
    Timeouts and connection failures must surface as
    :class:`google.auth.exceptions.TransportError`.
    """

    # Response bodies served by the test application. The tests read back
    # exactly len(body.encode("utf-8")) bytes from the response stream, so the
    # routes and the assertions share these constants instead of hard-coding
    # byte counts.
    _BASIC_BODY = "Basic Content"
    _ERROR_BODY = "Error"

    @pytest.fixture(scope="module")
    def server(self):
        """Provides a test HTTP server.

        The fixture is module-scoped: the server is started once, before the
        first test in the module that requests it, and stopped after that
        module's last test. It serves a small Flask application used to
        verify requests:

        * ``/basic`` -- 200 with the basic body; echoes the request's
          ``x-test-header`` (default ``"value"``) back as ``X-Test-Header``.
        * ``/server_error`` -- 500 with the error body.
        * ``/wait`` -- sleeps 3 seconds before replying, to exercise timeouts.
        """
        app = flask.Flask(__name__)
        app.debug = True

        # pylint: disable=unused-variable
        # (pylint thinks the flask routes are unused.)
        @app.route("/basic")
        def index():
            header_value = flask.request.headers.get("x-test-header", "value")
            headers = {"X-Test-Header": header_value}
            return self._BASIC_BODY, http.client.OK, headers

        @app.route("/server_error")
        def server_error():
            return self._ERROR_BODY, http.client.INTERNAL_SERVER_ERROR

        @app.route("/wait")
        def wait():
            # Must outlast the 1-second timeout used by
            # test_request_with_timeout_failure.
            time.sleep(3)
            return "Waited"

        # pylint: enable=unused-variable

        server = WSGIServer(application=app.wsgi_app)
        server.start()
        yield server
        server.stop()

    async def _assert_body(self, response, body):
        """Assert that ``response.data`` yields ``body``.

        Reads exactly as many bytes as ``body`` occupies when UTF-8 encoded,
        then compares them to the encoded ``body``.
        """
        expected = body.encode("utf-8")
        data = await response.data.read(len(expected))
        assert data == expected

    @pytest.mark.asyncio
    async def test_request_basic(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET")
        assert response.status == http.client.OK
        assert response.headers["x-test-header"] == "value"

        await self._assert_body(response, self._BASIC_BODY)

    @pytest.mark.asyncio
    async def test_request_basic_with_http(self, server):
        request = self.make_with_parameter_request()
        response = await request(url=server.url + "/basic", method="GET")
        assert response.status == http.client.OK
        assert response.headers["x-test-header"] == "value"

        await self._assert_body(response, self._BASIC_BODY)

    @pytest.mark.asyncio
    async def test_request_with_timeout_success(self, server):
        request = self.make_request()
        response = await request(url=server.url + "/basic", method="GET", timeout=2)

        assert response.status == http.client.OK
        assert response.headers["x-test-header"] == "value"

        await self._assert_body(response, self._BASIC_BODY)

    @pytest.mark.asyncio
    async def test_request_with_timeout_failure(self, server):
        request = self.make_request()

        # /wait sleeps for 3 seconds, so a 1-second timeout must trip.
        with pytest.raises(exceptions.TransportError):
            await request(url=server.url + "/wait", method="GET", timeout=1)

    @pytest.mark.asyncio
    async def test_request_headers(self, server):
        request = self.make_request()
        response = await request(
            url=server.url + "/basic",
            method="GET",
            headers={"x-test-header": "hello world"},
        )

        assert response.status == http.client.OK
        # /basic echoes the request header back in the response.
        assert response.headers["x-test-header"] == "hello world"

        await self._assert_body(response, self._BASIC_BODY)

    @pytest.mark.asyncio
    async def test_request_error(self, server):
        request = self.make_request()

        response = await request(url=server.url + "/server_error", method="GET")
        assert response.status == http.client.INTERNAL_SERVER_ERROR
        await self._assert_body(response, self._ERROR_BODY)

    @pytest.mark.asyncio
    async def test_connection_error(self):
        request = self.make_request()

        # compliance.NXDOMAIN is presumably a host name that never resolves,
        # so the transport must report a TransportError rather than hang.
        with pytest.raises(exceptions.TransportError):
            await request(url="http://{}".format(compliance.NXDOMAIN), method="GET")