diff options
author | Fumitoshi Ukai <ukai@google.com> | 2013-05-12 13:50:10 +0900 |
---|---|---|
committer | Mikio Hara <mikioh.mikioh@gmail.com> | 2013-05-12 13:50:10 +0900 |
commit | 0005f0a0c0c361196e53cd81bb023f8f563b85cf (patch) | |
tree | 07fce9f83758fa83523f358660e27b24f31b013a /websocket | |
parent | 94458b3b475d36ace16e4fe3a3ea82c8f290d15a (diff) | |
download | net-0005f0a0c0c361196e53cd81bb023f8f563b85cf.tar.gz |
go.net/websocket: allow server configurable
Add websocket.Server to configure WebSocket server handler.
- Config.Header is additional headers to send, so you can use it
to send cookies or so.
To read cookies, you can use Conn.Request().Header.
- factor out Handshake.
You can set func to check origin, subprotocol etc.
Handler checks origin by default.
Fixes golang/go#4198.
Fixes golang/go#5178.
R=golang-dev, mikioh.mikioh, crobin
CC=golang-dev
https://golang.org/cl/8731044
Diffstat (limited to 'websocket')
-rw-r--r-- | websocket/client.go | 2 | ||||
-rw-r--r-- | websocket/hybi.go | 49 | ||||
-rw-r--r-- | websocket/hybi_test.go | 65 | ||||
-rw-r--r-- | websocket/server.go | 65 | ||||
-rw-r--r-- | websocket/websocket.go | 3 | ||||
-rw-r--r-- | websocket/websocket_test.go | 63 |
6 files changed, 222 insertions, 25 deletions
diff --git a/websocket/client.go b/websocket/client.go index e59da0b..df54a68 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "io" "net" + "net/http" "net/url" ) @@ -34,6 +35,7 @@ func NewConfig(server, origin string) (config *Config, err error) { if err != nil { return } + config.Header = http.Header(make(map[string][]string)) return } diff --git a/websocket/hybi.go b/websocket/hybi.go index 0023d1c..c6ba6cf 100644 --- a/websocket/hybi.go +++ b/websocket/hybi.go @@ -46,6 +46,17 @@ var ( ErrBadClosingStatus = &ProtocolError{"bad closing status"} ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} ErrNotImplemented = &ProtocolError{"not implemented"} + + handshakeHeader = map[string]bool{ + "Host": true, + "Upgrade": true, + "Connection": true, + "Sec-Websocket-Key": true, + "Sec-Websocket-Origin": true, + "Sec-Websocket-Version": true, + "Sec-Websocket-Protocol": true, + "Sec-Websocket-Accept": true, + } ) // A hybiFrameHeader is a frame header as defined in hybi draft. @@ -408,8 +419,11 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er if len(config.Protocol) > 0 { bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") } - // TODO(ukai): send extensions. - // TODO(ukai): send cookie if any. + // TODO(ukai): send Sec-WebSocket-Extensions. + err = config.Header.WriteSubset(bw, handshakeHeader) + if err != nil { + return err + } bw.WriteString("\r\n") if err = bw.Flush(); err != nil { @@ -483,21 +497,14 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques return http.StatusBadRequest, ErrChallengeResponse } version := req.Header.Get("Sec-Websocket-Version") - var origin string switch version { case "13": c.Version = ProtocolVersionHybi13 - origin = req.Header.Get("Origin") case "8": c.Version = ProtocolVersionHybi08 - origin = req.Header.Get("Sec-Websocket-Origin") default: return http.StatusBadRequest, ErrBadWebSocketVersion } - c.Origin, err = url.ParseRequestURI(origin) - if err != nil { - return http.StatusForbidden, err - } var scheme string if req.TLS != nil { scheme = "wss" @@ -520,6 +527,22 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques return http.StatusSwitchingProtocols, nil } +// Origin parses Origin header in "req". +// If origin is "null", returns (nil, nil). +func Origin(config *Config, req *http.Request) (*url.URL, error) { + var origin string + switch config.Version { + case ProtocolVersionHybi13: + origin = req.Header.Get("Origin") + case ProtocolVersionHybi08: + origin = req.Header.Get("Sec-Websocket-Origin") + } + if origin == "null" { + return nil, nil + } + return url.ParseRequestURI(origin) +} + func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { if len(c.Protocol) > 0 { if len(c.Protocol) != 1 { @@ -533,7 +556,13 @@ func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { if len(c.Protocol) > 0 { buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") } - // TODO(ukai): support extensions + // TODO(ukai): send Sec-WebSocket-Extensions. + if c.Header != nil { + err := c.Header.WriteSubset(buf, handshakeHeader) + if err != nil { + return err + } + } buf.WriteString("\r\n") return buf.Flush() } diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go index b527e0b..01ed9e9 100644 --- a/websocket/hybi_test.go +++ b/websocket/hybi_test.go @@ -92,6 +92,71 @@ Sec-WebSocket-Protocol: chat } } +func TestHybiClientHandshakeWithHeader(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= +Sec-WebSocket-Protocol: chat + +`)) + var err error + config := new(Config) + config.Location, err = url.ParseRequestURI("ws://server.example.com/chat") + if err != nil { + t.Fatal("location url", err) + } + config.Origin, err = url.ParseRequestURI("http://example.com") + if err != nil { + t.Fatal("origin url", err) + } + config.Protocol = append(config.Protocol, "chat") + config.Protocol = append(config.Protocol, "superchat") + config.Version = ProtocolVersionHybi13 + config.Header = http.Header(make(map[string][]string)) + config.Header.Add("User-Agent", "test") + + config.handshakeData = map[string]string{ + "key": "dGhlIHNhbXBsZSBub25jZQ==", + } + err = hybiClientHandshake(config, br, bw) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + req, err := http.ReadRequest(bufio.NewReader(b)) + if err != nil { + t.Fatalf("read request: %v", err) + } + if req.Method != "GET" { + t.Errorf("request method expected GET, but got %q", req.Method) + } + if req.URL.Path != "/chat" { + t.Errorf("request path expected /chat, but got %q", req.URL.Path) + } + if req.Proto != "HTTP/1.1" { + t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) + } + if req.Host != "server.example.com" { + t.Errorf("request Host expected server.example.com, but got %v", req.Host) + } + var expectedHeader = map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Key": config.handshakeData["key"], + "Origin": config.Origin.String(), + "Sec-Websocket-Protocol": "chat, superchat", + "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13), + "User-Agent": "test", + } + for k, v := range expectedHeader { + if req.Header.Get(k) != v { + t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) + } + } +} + func TestHybiClientHandshakeHybi08(t *testing.T) { b := bytes.NewBuffer([]byte{}) bw := bufio.NewWriter(b) diff --git a/websocket/server.go b/websocket/server.go index 428bfb4..54e05b4 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -11,8 +11,7 @@ import ( "net/http" ) -func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err error) { - config := new(Config) +func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { var hs serverHandshaker = &hybiServerHandshaker{Config: config} code, err := hs.ReadHandshake(buf.Reader, req) if err == ErrBadWebSocketVersion { @@ -38,8 +37,16 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ buf.Flush() return } - config.Protocol = nil - + if handshake != nil { + err = handshake(config, req) + if err != nil { + code = http.StatusForbidden + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + } err = hs.AcceptHandshake(buf.Writer) if err != nil { code = http.StatusBadRequest @@ -52,11 +59,26 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ return } -// Handler is an interface to a WebSocket. -type Handler func(*Conn) +// Server represents a server of a WebSocket. +type Server struct { + // Config is a WebSocket configuration for new WebSocket connection. + Config -// ServeHTTP implements the http.Handler interface for a Web Socket -func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Handshake is an optional function in WebSocket handshake. + // For example, you can check, or don't check Origin header. + // Another example, you can select config.Protocol. + Handshake func(*Config, *http.Request) error + + // Handler handles a WebSocket connection. + Handler +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.serveWebSocket(w, req) +} + +func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { rwc, buf, err := w.(http.Hijacker).Hijack() if err != nil { panic("Hijack failed: " + err.Error()) @@ -66,12 +88,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // the client did not send a handshake that matches with protocol // specification. defer rwc.Close() - conn, err := newServerConn(rwc, buf, req) + conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) if err != nil { return } if conn == nil { panic("unexpected nil conn") } - h(conn) + s.Handler(conn) +} + +// Handler is a simple interface to a WebSocket browser client. +// It checks if Origin header is valid URL by default. +// You might want to verify websocket.Conn.Config().Origin in the func. +// If you use Server instead of Handler, you could call websocket.Origin and +// check the origin in your Handshake func. So, if you want to accept +// non-browser client, which doesn't send Origin header, you could use Server +//. that doesn't check origin in its Handshake. +type Handler func(*Conn) + +func checkOrigin(config *Config, req *http.Request) (err error) { + config.Origin, err = Origin(config, req) + if err == nil && config.Origin == nil { + return fmt.Errorf("null origin") + } + return err +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s := Server{Handler: h, Handshake: checkOrigin} + s.serveWebSocket(w, req) } diff --git a/websocket/websocket.go b/websocket/websocket.go index 793e510..861b3c6 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -87,6 +87,9 @@ type Config struct { // TLS config for secure WebSocket (wss). TlsConfig *tls.Config + // Additional header fields to be sent in WebSocket opening handshake. + Header http.Header + handshakeData map[string]string } diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 40c147f..53e445b 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -44,9 +44,30 @@ func countServer(ws *Conn) { } } +func subProtocolHandshake(config *Config, req *http.Request) error { + for _, proto := range config.Protocol { + if proto == "chat" { + config.Protocol = []string{proto} + return nil + } + } + return ErrBadWebSocketProtocol +} + +func subProtoServer(ws *Conn) { + for _, proto := range ws.Config().Protocol { + io.WriteString(ws, proto) + } +} + func startServer() { http.Handle("/echo", Handler(echoServer)) http.Handle("/count", Handler(countServer)) + subproto := Server{ + Handshake: subProtocolHandshake, + Handler: Handler(subProtoServer), + } + http.Handle("/subproto", subproto) server := httptest.NewServer(nil) serverAddr = server.Listener.Addr().String() log.Print("Test WebSocket server listening on ", serverAddr) @@ -177,7 +198,7 @@ func TestWithQuery(t *testing.T) { ws.Close() } -func TestWithProtocol(t *testing.T) { +func testWithProtocol(t *testing.T, subproto []string) (string, error) { once.Do(startServer) client, err := net.Dial("tcp", serverAddr) @@ -185,15 +206,47 @@ func TestWithProtocol(t *testing.T) { t.Fatal("dialing", err) } - config := newConfig(t, "/echo") - config.Protocol = append(config.Protocol, "test") + config := newConfig(t, "/subproto") + config.Protocol = subproto ws, err := NewClient(config, client) if err != nil { - t.Errorf("WebSocket handshake: %v", err) - return + return "", err + } + msg := make([]byte, 16) + n, err := ws.Read(msg) + if err != nil { + return "", err } ws.Close() + return string(msg[:n]), nil +} + +func TestWithProtocol(t *testing.T) { + proto, err := testWithProtocol(t, []string{"chat"}) + if err != nil { + t.Errorf("SubProto: unexpected error: %v", err) + } + if proto != "chat" { + t.Errorf("SubProto: expected %q, got %q", "chat", proto) + } +} + +func TestWithTwoProtocol(t *testing.T) { + proto, err := testWithProtocol(t, []string{"test", "chat"}) + if err != nil { + t.Errorf("SubProto: unexpected error: %v", err) + } + if proto != "chat" { + t.Errorf("SubProto: expected %q, got %q", "chat", proto) + } +} + +func TestWithBadProtocol(t *testing.T) { + _, err := testWithProtocol(t, []string{"test"}) + if err != ErrBadStatus { + t.Errorf("SubProto: expected %q, got %q", ErrBadStatus) + } } func TestHTTP(t *testing.T) { |