aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikio Hara <mikioh.mikioh@gmail.com>2015-08-03 19:49:49 +0900
committerMikio Hara <mikioh.mikioh@gmail.com>2015-08-04 03:35:05 +0000
commitb963d2882af875a5309cdc4b00f93a638f3646d2 (patch)
tree47d43eb69f02f43a82db50573b99b444b72ac0d3
parent84649876d01099f6fe7f9b81182d0b803ccbd612 (diff)
downloadnet-b963d2882af875a5309cdc4b00f93a638f3646d2.tar.gz
websocket: handle solicited and unsolicited Ping/Pong frames correctly
This change prevents Read from failing with io.EOF, ErrNotImplemented on exchanging control frames such as ping and pong. Fixes golang/go#6377. Fixes golang/go#7825. Fixes golang/go#10156. Change-Id: I600cf493de3671d7e3d11e2e12d32f43928b7bfc Reviewed-on: https://go-review.googlesource.com/13054 Reviewed-by: Andrew Gerrand <adg@golang.org>
-rw-r--r--websocket/hybi.go19
-rw-r--r--websocket/hybi_test.go11
-rw-r--r--websocket/websocket_test.go95
3 files changed, 112 insertions, 13 deletions
diff --git a/websocket/hybi.go b/websocket/hybi.go
index b965a34..c430dce 100644
--- a/websocket/hybi.go
+++ b/websocket/hybi.go
@@ -267,7 +267,7 @@ type hybiFrameHandler struct {
payloadType byte
}
-func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
+func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
@@ -291,20 +291,19 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader,
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
- case PingFrame:
- pingMsg := make([]byte, maxControlFramePayloadLength)
- n, err := io.ReadFull(frame, pingMsg)
- if err != nil && err != io.ErrUnexpectedEOF {
+ case PingFrame, PongFrame:
+ b := make([]byte, maxControlFramePayloadLength)
+ n, err := io.ReadFull(frame, b)
+ if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
- n, err = handler.WritePong(pingMsg[:n])
- if err != nil {
- return nil, err
+ if frame.PayloadType() == PingFrame {
+ if _, err := handler.WritePong(b[:n]); err != nil {
+ return nil, err
+ }
}
return nil, nil
- case PongFrame:
- return nil, ErrNotImplemented
}
return frame, nil
}
diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go
index d6a1910..722ab31 100644
--- a/websocket/hybi_test.go
+++ b/websocket/hybi_test.go
@@ -326,7 +326,7 @@ func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []by
}
payload := make([]byte, len(testPayload))
_, err = r.Read(payload)
- if err != nil {
+ if err != nil && err != io.EOF {
t.Errorf("read %v", err)
}
if !bytes.Equal(testPayload, payload) {
@@ -363,13 +363,20 @@ func TestHybiShortBinaryFrame(t *testing.T) {
}
func TestHybiControlFrame(t *testing.T) {
- frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
payload := []byte("hello")
+
+ frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader)
+ frameHeader = &hybiFrameHeader{Fin: true, OpCode: PingFrame}
+ testHybiFrame(t, []byte{0x89, 0x00}, nil, nil, frameHeader)
+
frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader)
+ frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
+ testHybiFrame(t, []byte{0x8A, 0x00}, nil, nil, frameHeader)
+
frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame}
payload = []byte{0x03, 0xe8} // 1000
testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader)
diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go
index 7841af2..05b7e53 100644
--- a/websocket/websocket_test.go
+++ b/websocket/websocket_test.go
@@ -24,7 +24,10 @@ import (
var serverAddr string
var once sync.Once
-func echoServer(ws *Conn) { io.Copy(ws, ws) }
+func echoServer(ws *Conn) {
+ defer ws.Close()
+ io.Copy(ws, ws)
+}
type Count struct {
S string
@@ -32,6 +35,7 @@ type Count struct {
}
func countServer(ws *Conn) {
+ defer ws.Close()
for {
var count Count
err := JSON.Receive(ws, &count)
@@ -47,6 +51,55 @@ func countServer(ws *Conn) {
}
}
+type testCtrlAndDataHandler struct {
+ hybiFrameHandler
+}
+
+func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
+ h.hybiFrameHandler.conn.wio.Lock()
+ defer h.hybiFrameHandler.conn.wio.Unlock()
+ w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
+ if err != nil {
+ return 0, err
+ }
+ n, err := w.Write(b)
+ w.Close()
+ return n, err
+}
+
+func ctrlAndDataServer(ws *Conn) {
+ defer ws.Close()
+ h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+ ws.frameHandler = h
+
+ go func() {
+ for i := 0; ; i++ {
+ var b []byte
+ if i%2 != 0 { // with or without payload
+ b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
+ }
+ if _, err := h.WritePing(b); err != nil {
+ break
+ }
+ if _, err := h.WritePong(b); err != nil { // unsolicited pong
+ break
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ }()
+
+ b := make([]byte, 128)
+ for {
+ n, err := ws.Read(b)
+ if err != nil {
+ break
+ }
+ if _, err := ws.Write(b[:n]); err != nil {
+ break
+ }
+ }
+}
+
func subProtocolHandshake(config *Config, req *http.Request) error {
for _, proto := range config.Protocol {
if proto == "chat" {
@@ -66,6 +119,7 @@ func subProtoServer(ws *Conn) {
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
+ http.Handle("/ctrldata", Handler(ctrlAndDataServer))
subproto := Server{
Handshake: subProtocolHandshake,
Handler: Handler(subProtoServer),
@@ -492,3 +546,42 @@ func TestOrigin(t *testing.T) {
}
}
}
+
+func TestCtrlAndData(t *testing.T) {
+ once.Do(startServer)
+
+ c, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ws, err := NewClient(newConfig(t, "/ctrldata"), c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ws.Close()
+
+ h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+ ws.frameHandler = h
+
+ b := make([]byte, 128)
+ for i := 0; i < 2; i++ {
+ data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
+ if _, err := ws.Write(data); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ var ctrl []byte
+ if i%2 != 0 { // with or without payload
+ ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
+ }
+ if _, err := h.WritePing(ctrl); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ n, err := ws.Read(b)
+ if err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ if !bytes.Equal(b[:n], data) {
+ t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
+ }
+ }
+}