aboutsummaryrefslogtreecommitdiff
path: root/websocket
diff options
context:
space:
mode:
authorFumitoshi Ukai <ukai@google.com>2013-05-12 13:50:10 +0900
committerMikio Hara <mikioh.mikioh@gmail.com>2013-05-12 13:50:10 +0900
commit0005f0a0c0c361196e53cd81bb023f8f563b85cf (patch)
tree07fce9f83758fa83523f358660e27b24f31b013a /websocket
parent94458b3b475d36ace16e4fe3a3ea82c8f290d15a (diff)
downloadnet-0005f0a0c0c361196e53cd81bb023f8f563b85cf.tar.gz
go.net/websocket: allow server configurable
Add websocket.Server to configure WebSocket server handler. - Config.Header is additional headers to send, so you can use it to send cookies or so. To read cookies, you can use Conn.Request().Header. - factor out Handshake. You can set func to check origin, subprotocol etc. Handler checks origin by default. Fixes golang/go#4198. Fixes golang/go#5178. R=golang-dev, mikioh.mikioh, crobin CC=golang-dev https://golang.org/cl/8731044
Diffstat (limited to 'websocket')
-rw-r--r--websocket/client.go2
-rw-r--r--websocket/hybi.go49
-rw-r--r--websocket/hybi_test.go65
-rw-r--r--websocket/server.go65
-rw-r--r--websocket/websocket.go3
-rw-r--r--websocket/websocket_test.go63
6 files changed, 222 insertions, 25 deletions
diff --git a/websocket/client.go b/websocket/client.go
index e59da0b..df54a68 100644
--- a/websocket/client.go
+++ b/websocket/client.go
@@ -9,6 +9,7 @@ import (
"crypto/tls"
"io"
"net"
+ "net/http"
"net/url"
)
@@ -34,6 +35,7 @@ func NewConfig(server, origin string) (config *Config, err error) {
if err != nil {
return
}
+ config.Header = http.Header(make(map[string][]string))
return
}
diff --git a/websocket/hybi.go b/websocket/hybi.go
index 0023d1c..c6ba6cf 100644
--- a/websocket/hybi.go
+++ b/websocket/hybi.go
@@ -46,6 +46,17 @@ var (
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
+
+ handshakeHeader = map[string]bool{
+ "Host": true,
+ "Upgrade": true,
+ "Connection": true,
+ "Sec-Websocket-Key": true,
+ "Sec-Websocket-Origin": true,
+ "Sec-Websocket-Version": true,
+ "Sec-Websocket-Protocol": true,
+ "Sec-Websocket-Accept": true,
+ }
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
@@ -408,8 +419,11 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
- // TODO(ukai): send extensions.
- // TODO(ukai): send cookie if any.
+ // TODO(ukai): send Sec-WebSocket-Extensions.
+ err = config.Header.WriteSubset(bw, handshakeHeader)
+ if err != nil {
+ return err
+ }
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
@@ -483,21 +497,14 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
- var origin string
switch version {
case "13":
c.Version = ProtocolVersionHybi13
- origin = req.Header.Get("Origin")
case "8":
c.Version = ProtocolVersionHybi08
- origin = req.Header.Get("Sec-Websocket-Origin")
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
- c.Origin, err = url.ParseRequestURI(origin)
- if err != nil {
- return http.StatusForbidden, err
- }
var scheme string
if req.TLS != nil {
scheme = "wss"
@@ -520,6 +527,22 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
return http.StatusSwitchingProtocols, nil
}
+// Origin parses Origin header in "req".
+// If origin is "null", returns (nil, nil).
+func Origin(config *Config, req *http.Request) (*url.URL, error) {
+ var origin string
+ switch config.Version {
+ case ProtocolVersionHybi13:
+ origin = req.Header.Get("Origin")
+ case ProtocolVersionHybi08:
+ origin = req.Header.Get("Sec-Websocket-Origin")
+ }
+ if origin == "null" {
+ return nil, nil
+ }
+ return url.ParseRequestURI(origin)
+}
+
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
@@ -533,7 +556,13 @@ func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
- // TODO(ukai): support extensions
+ // TODO(ukai): send Sec-WebSocket-Extensions.
+ if c.Header != nil {
+ err := c.Header.WriteSubset(buf, handshakeHeader)
+ if err != nil {
+ return err
+ }
+ }
buf.WriteString("\r\n")
return buf.Flush()
}
diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go
index b527e0b..01ed9e9 100644
--- a/websocket/hybi_test.go
+++ b/websocket/hybi_test.go
@@ -92,6 +92,71 @@ Sec-WebSocket-Protocol: chat
}
}
+func TestHybiClientHandshakeWithHeader(t *testing.T) {
+ b := bytes.NewBuffer([]byte{})
+ bw := bufio.NewWriter(b)
+ br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
+Upgrade: websocket
+Connection: Upgrade
+Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+Sec-WebSocket-Protocol: chat
+
+`))
+ var err error
+ config := new(Config)
+ config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
+ if err != nil {
+ t.Fatal("location url", err)
+ }
+ config.Origin, err = url.ParseRequestURI("http://example.com")
+ if err != nil {
+ t.Fatal("origin url", err)
+ }
+ config.Protocol = append(config.Protocol, "chat")
+ config.Protocol = append(config.Protocol, "superchat")
+ config.Version = ProtocolVersionHybi13
+ config.Header = http.Header(make(map[string][]string))
+ config.Header.Add("User-Agent", "test")
+
+ config.handshakeData = map[string]string{
+ "key": "dGhlIHNhbXBsZSBub25jZQ==",
+ }
+ err = hybiClientHandshake(config, br, bw)
+ if err != nil {
+ t.Errorf("handshake failed: %v", err)
+ }
+ req, err := http.ReadRequest(bufio.NewReader(b))
+ if err != nil {
+ t.Fatalf("read request: %v", err)
+ }
+ if req.Method != "GET" {
+ t.Errorf("request method expected GET, but got %q", req.Method)
+ }
+ if req.URL.Path != "/chat" {
+ t.Errorf("request path expected /chat, but got %q", req.URL.Path)
+ }
+ if req.Proto != "HTTP/1.1" {
+ t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
+ }
+ if req.Host != "server.example.com" {
+ t.Errorf("request Host expected server.example.com, but got %v", req.Host)
+ }
+ var expectedHeader = map[string]string{
+ "Connection": "Upgrade",
+ "Upgrade": "websocket",
+ "Sec-Websocket-Key": config.handshakeData["key"],
+ "Origin": config.Origin.String(),
+ "Sec-Websocket-Protocol": "chat, superchat",
+ "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
+ "User-Agent": "test",
+ }
+ for k, v := range expectedHeader {
+ if req.Header.Get(k) != v {
+ t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
+ }
+ }
+}
+
func TestHybiClientHandshakeHybi08(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
diff --git a/websocket/server.go b/websocket/server.go
index 428bfb4..54e05b4 100644
--- a/websocket/server.go
+++ b/websocket/server.go
@@ -11,8 +11,7 @@ import (
"net/http"
)
-func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err error) {
- config := new(Config)
+func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
@@ -38,8 +37,16 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
buf.Flush()
return
}
- config.Protocol = nil
-
+ if handshake != nil {
+ err = handshake(config, req)
+ if err != nil {
+ code = http.StatusForbidden
+ fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
+ buf.WriteString("\r\n")
+ buf.Flush()
+ return
+ }
+ }
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
@@ -52,11 +59,26 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ
return
}
-// Handler is an interface to a WebSocket.
-type Handler func(*Conn)
+// Server represents a server of a WebSocket.
+type Server struct {
+ // Config is a WebSocket configuration for new WebSocket connection.
+ Config
-// ServeHTTP implements the http.Handler interface for a Web Socket
-func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Handshake is an optional function in WebSocket handshake.
+ // For example, you can check, or don't check Origin header.
+ // Another example, you can select config.Protocol.
+ Handshake func(*Config, *http.Request) error
+
+ // Handler handles a WebSocket connection.
+ Handler
+}
+
+// ServeHTTP implements the http.Handler interface for a WebSocket
+func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ s.serveWebSocket(w, req)
+}
+
+func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
@@ -66,12 +88,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
- conn, err := newServerConn(rwc, buf, req)
+ conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
- h(conn)
+ s.Handler(conn)
+}
+
+// Handler is a simple interface to a WebSocket browser client.
+// It checks if Origin header is valid URL by default.
+// You might want to verify websocket.Conn.Config().Origin in the func.
+// If you use Server instead of Handler, you could call websocket.Origin and
+// check the origin in your Handshake func. So, if you want to accept
+// non-browser client, which doesn't send Origin header, you could use Server
+//. that doesn't check origin in its Handshake.
+type Handler func(*Conn)
+
+func checkOrigin(config *Config, req *http.Request) (err error) {
+ config.Origin, err = Origin(config, req)
+ if err == nil && config.Origin == nil {
+ return fmt.Errorf("null origin")
+ }
+ return err
+}
+
+// ServeHTTP implements the http.Handler interface for a WebSocket
+func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ s := Server{Handler: h, Handshake: checkOrigin}
+ s.serveWebSocket(w, req)
}
diff --git a/websocket/websocket.go b/websocket/websocket.go
index 793e510..861b3c6 100644
--- a/websocket/websocket.go
+++ b/websocket/websocket.go
@@ -87,6 +87,9 @@ type Config struct {
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
+ // Additional header fields to be sent in WebSocket opening handshake.
+ Header http.Header
+
handshakeData map[string]string
}
diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go
index 40c147f..53e445b 100644
--- a/websocket/websocket_test.go
+++ b/websocket/websocket_test.go
@@ -44,9 +44,30 @@ func countServer(ws *Conn) {
}
}
+func subProtocolHandshake(config *Config, req *http.Request) error {
+ for _, proto := range config.Protocol {
+ if proto == "chat" {
+ config.Protocol = []string{proto}
+ return nil
+ }
+ }
+ return ErrBadWebSocketProtocol
+}
+
+func subProtoServer(ws *Conn) {
+ for _, proto := range ws.Config().Protocol {
+ io.WriteString(ws, proto)
+ }
+}
+
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
+ subproto := Server{
+ Handshake: subProtocolHandshake,
+ Handler: Handler(subProtoServer),
+ }
+ http.Handle("/subproto", subproto)
server := httptest.NewServer(nil)
serverAddr = server.Listener.Addr().String()
log.Print("Test WebSocket server listening on ", serverAddr)
@@ -177,7 +198,7 @@ func TestWithQuery(t *testing.T) {
ws.Close()
}
-func TestWithProtocol(t *testing.T) {
+func testWithProtocol(t *testing.T, subproto []string) (string, error) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
@@ -185,15 +206,47 @@ func TestWithProtocol(t *testing.T) {
t.Fatal("dialing", err)
}
- config := newConfig(t, "/echo")
- config.Protocol = append(config.Protocol, "test")
+ config := newConfig(t, "/subproto")
+ config.Protocol = subproto
ws, err := NewClient(config, client)
if err != nil {
- t.Errorf("WebSocket handshake: %v", err)
- return
+ return "", err
+ }
+ msg := make([]byte, 16)
+ n, err := ws.Read(msg)
+ if err != nil {
+ return "", err
}
ws.Close()
+ return string(msg[:n]), nil
+}
+
+func TestWithProtocol(t *testing.T) {
+ proto, err := testWithProtocol(t, []string{"chat"})
+ if err != nil {
+ t.Errorf("SubProto: unexpected error: %v", err)
+ }
+ if proto != "chat" {
+ t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+ }
+}
+
+func TestWithTwoProtocol(t *testing.T) {
+ proto, err := testWithProtocol(t, []string{"test", "chat"})
+ if err != nil {
+ t.Errorf("SubProto: unexpected error: %v", err)
+ }
+ if proto != "chat" {
+ t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+ }
+}
+
+func TestWithBadProtocol(t *testing.T) {
+ _, err := testWithProtocol(t, []string{"test"})
+ if err != ErrBadStatus {
+ t.Errorf("SubProto: expected %q, got %q", ErrBadStatus)
+ }
}
func TestHTTP(t *testing.T) {