diff options
author | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-08-03 19:49:49 +0900 |
---|---|---|
committer | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-08-04 03:35:05 +0000 |
commit | b963d2882af875a5309cdc4b00f93a638f3646d2 (patch) | |
tree | 47d43eb69f02f43a82db50573b99b444b72ac0d3 | |
parent | 84649876d01099f6fe7f9b81182d0b803ccbd612 (diff) | |
download | net-b963d2882af875a5309cdc4b00f93a638f3646d2.tar.gz |
websocket: handle solicited and unsolicited Ping/Pong frames correctly
This change prevents Read from failing with io.EOF, ErrNotImplemented on
exchanging control frames such as ping and pong.
Fixes golang/go#6377.
Fixes golang/go#7825.
Fixes golang/go#10156.
Change-Id: I600cf493de3671d7e3d11e2e12d32f43928b7bfc
Reviewed-on: https://go-review.googlesource.com/13054
Reviewed-by: Andrew Gerrand <adg@golang.org>
-rw-r--r-- | websocket/hybi.go | 19 | ||||
-rw-r--r-- | websocket/hybi_test.go | 11 | ||||
-rw-r--r-- | websocket/websocket_test.go | 95 |
3 files changed, 112 insertions, 13 deletions
diff --git a/websocket/hybi.go b/websocket/hybi.go index b965a34..c430dce 100644 --- a/websocket/hybi.go +++ b/websocket/hybi.go @@ -267,7 +267,7 @@ type hybiFrameHandler struct { payloadType byte } -func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) { +func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { if handler.conn.IsServerConn() { // The client MUST mask all frames sent to the server. if frame.(*hybiFrameReader).header.MaskingKey == nil { @@ -291,20 +291,19 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, handler.payloadType = frame.PayloadType() case CloseFrame: return nil, io.EOF - case PingFrame: - pingMsg := make([]byte, maxControlFramePayloadLength) - n, err := io.ReadFull(frame, pingMsg) - if err != nil && err != io.ErrUnexpectedEOF { + case PingFrame, PongFrame: + b := make([]byte, maxControlFramePayloadLength) + n, err := io.ReadFull(frame, b) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return nil, err } io.Copy(ioutil.Discard, frame) - n, err = handler.WritePong(pingMsg[:n]) - if err != nil { - return nil, err + if frame.PayloadType() == PingFrame { + if _, err := handler.WritePong(b[:n]); err != nil { + return nil, err + } } return nil, nil - case PongFrame: - return nil, ErrNotImplemented } return frame, nil } diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go index d6a1910..722ab31 100644 --- a/websocket/hybi_test.go +++ b/websocket/hybi_test.go @@ -326,7 +326,7 @@ func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []by } payload := make([]byte, len(testPayload)) _, err = r.Read(payload) - if err != nil { + if err != nil && err != io.EOF { t.Errorf("read %v", err) } if !bytes.Equal(testPayload, payload) { @@ -363,13 +363,20 @@ func TestHybiShortBinaryFrame(t *testing.T) { } func TestHybiControlFrame(t *testing.T) { - frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame} payload := []byte("hello") + + frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame} testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader) + frameHeader = &hybiFrameHeader{Fin: true, OpCode: PingFrame} + testHybiFrame(t, []byte{0x89, 0x00}, nil, nil, frameHeader) + frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame} testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader) + frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame} + testHybiFrame(t, []byte{0x8A, 0x00}, nil, nil, frameHeader) + frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame} payload = []byte{0x03, 0xe8} // 1000 testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader) diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 7841af2..05b7e53 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -24,7 +24,10 @@ import ( var serverAddr string var once sync.Once -func echoServer(ws *Conn) { io.Copy(ws, ws) } +func echoServer(ws *Conn) { + defer ws.Close() + io.Copy(ws, ws) +} type Count struct { S string @@ -32,6 +35,7 @@ type Count struct { } func countServer(ws *Conn) { + defer ws.Close() for { var count Count err := JSON.Receive(ws, &count) @@ -47,6 +51,55 @@ func countServer(ws *Conn) { } } +type testCtrlAndDataHandler struct { + hybiFrameHandler +} + +func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) { + h.hybiFrameHandler.conn.wio.Lock() + defer h.hybiFrameHandler.conn.wio.Unlock() + w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame) + if err != nil { + return 0, err + } + n, err := w.Write(b) + w.Close() + return n, err +} + +func ctrlAndDataServer(ws *Conn) { + defer ws.Close() + h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} + ws.frameHandler = h + + go func() { + for i := 0; ; i++ { + var b []byte + if i%2 != 0 { // with or without payload + b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i)) + } + if _, err := h.WritePing(b); err != nil { + break + } + if _, err := h.WritePong(b); err != nil { // unsolicited pong + break + } + time.Sleep(10 * time.Millisecond) + } + }() + + b := make([]byte, 128) + for { + n, err := ws.Read(b) + if err != nil { + break + } + if _, err := ws.Write(b[:n]); err != nil { + break + } + } +} + func subProtocolHandshake(config *Config, req *http.Request) error { for _, proto := range config.Protocol { if proto == "chat" { @@ -66,6 +119,7 @@ func subProtoServer(ws *Conn) { func startServer() { http.Handle("/echo", Handler(echoServer)) http.Handle("/count", Handler(countServer)) + http.Handle("/ctrldata", Handler(ctrlAndDataServer)) subproto := Server{ Handshake: subProtocolHandshake, Handler: Handler(subProtoServer), @@ -492,3 +546,42 @@ func TestOrigin(t *testing.T) { } } } + +func TestCtrlAndData(t *testing.T) { + once.Do(startServer) + + c, err := net.Dial("tcp", serverAddr) + if err != nil { + t.Fatal(err) + } + ws, err := NewClient(newConfig(t, "/ctrldata"), c) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} + ws.frameHandler = h + + b := make([]byte, 128) + for i := 0; i < 2; i++ { + data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i)) + if _, err := ws.Write(data); err != nil { + t.Fatalf("#%d: %v", i, err) + } + var ctrl []byte + if i%2 != 0 { // with or without payload + ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i)) + } + if _, err := h.WritePing(ctrl); err != nil { + t.Fatalf("#%d: %v", i, err) + } + n, err := ws.Read(b) + if err != nil { + t.Fatalf("#%d: %v", i, err) + } + if !bytes.Equal(b[:n], data) { + t.Fatalf("#%d: got %v; want %v", i, b[:n], data) + } + } +} |