aboutsummaryrefslogtreecommitdiff
path: root/websocket/src/test/java/fi/iki/elonen/WebSocketResponseHandlerTest.java
blob: c007fdec06bdba1298f7f1b149eaac9d636a0acc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package fi.iki.elonen;

import java.io.IOException;

import org.junit.Test;

import fi.iki.elonen.NanoHTTPD.IHTTPSession;
import fi.iki.elonen.NanoHTTPD.Response;
import fi.iki.elonen.WebSocketFrame.CloseCode;
import fi.iki.elonen.testutil.MockHttpSession;
import static junit.framework.Assert.*;

/**
 * Tests the HTTP upgrade handshake produced by {@link WebSocketResponseHandler#serve}.
 *
 * <p>The handler is wired to a factory that always returns the same {@link WebSocketAdapter},
 * whose event callbacks fail the test if they are ever invoked; these tests only inspect the
 * handshake {@link Response} and never exchange frames.
 */
public class WebSocketResponseHandlerTest {
    private final WebSocketResponseHandler responseHandler = new WebSocketResponseHandler(
            new DummyWebSocketFactory(new WebSocketAdapter(new MockHttpSession())));

    /**
     * A well-formed version-13 upgrade request must be answered with "101 Switching Protocols",
     * the upgrade/connection headers, the Sec-WebSocket-Accept value expected for the sample key,
     * and "chat" (the first sub-protocol the request offers) as the selected protocol.
     */
    @Test
    public void testHandshake_returnsExpectedHeaders() {
        MockHttpSession session = createWebSocketHandshakeRequest();

        Response handshakeResponse = responseHandler.serve(session);

        assertNotNull(handshakeResponse);
        assertEquals(101, handshakeResponse.getStatus().getRequestStatus());
        assertEquals("101 Switching Protocols", handshakeResponse.getStatus().getDescription());
        assertEquals("websocket", handshakeResponse.getHeader("upgrade"));
        assertEquals("Upgrade", handshakeResponse.getHeader("connection"));
        assertEquals("HSmrc0sMlYUkAGmm5OPpG2HaGWk=", handshakeResponse.getHeader("sec-websocket-accept"));
        assertEquals("chat", handshakeResponse.getHeader("sec-websocket-protocol"));
    }

    /**
     * A request that is otherwise valid but advertises protocol version 12 instead of 13 must be
     * rejected with "400 Bad Request".
     */
    @Test
    public void testWrongWebsocketVersion_returnsErrorResponse() {
        MockHttpSession session = createWebSocketHandshakeRequest();
        session.getHeaders().put("sec-websocket-version", "12");

        Response handshakeResponse = responseHandler.serve(session);

        assertNotNull(handshakeResponse);
        assertEquals(400, handshakeResponse.getStatus().getRequestStatus());
        assertEquals("400 Bad Request", handshakeResponse.getStatus().getDescription());
    }

    /**
     * Builds a fresh mock session carrying a complete client handshake. Callers may mutate the
     * returned session's header map to derive variants (e.g. a different version).
     *
     * @return a new session with upgrade, connection, key, protocol and version headers set
     */
    private MockHttpSession createWebSocketHandshakeRequest() {
        // Example headers copied from Wikipedia's WebSocket article; the expected
        // Sec-WebSocket-Accept value asserted above corresponds to this key.
        MockHttpSession session = new MockHttpSession();
        session.getHeaders().put("upgrade", "websocket");
        session.getHeaders().put("connection", "Upgrade");
        session.getHeaders().put("sec-websocket-key", "x3JJHMbDL1EzLkh9GBhXDw==");
        session.getHeaders().put("sec-websocket-protocol", "chat, superchat");
        session.getHeaders().put("sec-websocket-version", "13");
        return session;
    }

    /**
     * Factory that ignores the handshake request and always returns the one socket it was
     * constructed with.
     */
    private static class DummyWebSocketFactory implements WebSocketFactory {
        private final WebSocket webSocket;

        private DummyWebSocketFactory(WebSocket webSocket) {
            this.webSocket = webSocket;
        }

        @Override
        public WebSocket openWebSocket(IHTTPSession handshake) {
            return webSocket;
        }
    }

    /**
     * WebSocket whose event callbacks are never expected to run in these handshake-only tests;
     * each one fails the test with an {@link AssertionError} if reached.
     */
    private static class WebSocketAdapter extends WebSocket {
        private static final String UNEXPECTED_CALL = "this method should not have been called";

        public WebSocketAdapter(IHTTPSession handshakeRequest) {
            super(handshakeRequest);
        }

        @Override
        protected void onPong(WebSocketFrame pongFrame) {
            throw new AssertionError(UNEXPECTED_CALL);
        }

        @Override
        protected void onMessage(WebSocketFrame messageFrame) {
            throw new AssertionError(UNEXPECTED_CALL);
        }

        @Override
        protected void onClose(CloseCode code, String reason,
                boolean initiatedByRemote) {
            throw new AssertionError(UNEXPECTED_CALL);
        }

        @Override
        protected void onException(IOException e) {
            throw new AssertionError(UNEXPECTED_CALL);
        }
    }
}