diff options
author | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-03-07 18:53:53 +0900 |
---|---|---|
committer | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-07-30 07:09:39 +0000 |
commit | d9b482f8ab1ac7e51338937406111ea7ce342512 (patch) | |
tree | 9f2053c57bf337d9a377d6d1b2453d4efd878c5e | |
parent | 7b0ed266d7d14de540f6f1804780ad5304539164 (diff) | |
download | net-d9b482f8ab1ac7e51338937406111ea7ce342512.tar.gz |
websocket: fix mis-handshake in the case of lack of HTTP origin header
Fixes golang/go#10102.
Change-Id: I34779a81797cb3b7e8820f5af8b0dde54f949164
Reviewed-on: https://go-review.googlesource.com/7034
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
-rw-r--r-- | websocket/hybi.go | 6 | ||||
-rw-r--r-- | websocket/websocket_test.go | 37 |
2 files changed, 40 insertions, 3 deletions
diff --git a/websocket/hybi.go b/websocket/hybi.go index f8c0b2e..41c854e 100644 --- a/websocket/hybi.go +++ b/websocket/hybi.go @@ -515,15 +515,15 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques return http.StatusSwitchingProtocols, nil } -// Origin parses Origin header in "req". -// If origin is "null", returns (nil, nil). +// Origin parses the Origin header in req. +// If the Origin header is not set, it returns nil and nil. func Origin(config *Config, req *http.Request) (*url.URL, error) { var origin string switch config.Version { case ProtocolVersionHybi13: origin = req.Header.Get("Origin") } - if origin == "null" { + if origin == "" { return nil, nil } return url.ParseRequestURI(origin) diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 725a79f..afc3eed 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -13,6 +13,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" "strings" "sync" "testing" @@ -450,3 +451,39 @@ func TestClose(t *testing.T) { t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed) } } + +var originTests = []struct { + req *http.Request + origin *url.URL +}{ + { + req: &http.Request{ + Header: http.Header{ + "Origin": []string{"http://www.example.com"}, + }, + }, + origin: &url.URL{ + Scheme: "http", + Host: "www.example.com", + }, + }, + { + req: &http.Request{}, + }, +} + +func TestOrigin(t *testing.T) { + conf := newConfig(t, "/echo") + conf.Version = ProtocolVersionHybi13 + for i, tt := range originTests { + origin, err := Origin(conf, tt.req) + if err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(origin, tt.origin) { + t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin) + continue + } + } +} |