aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikio Hara <mikioh.mikioh@gmail.com>2015-03-07 18:53:53 +0900
committerMikio Hara <mikioh.mikioh@gmail.com>2015-07-30 07:09:39 +0000
commitd9b482f8ab1ac7e51338937406111ea7ce342512 (patch)
tree9f2053c57bf337d9a377d6d1b2453d4efd878c5e
parent7b0ed266d7d14de540f6f1804780ad5304539164 (diff)
downloadnet-d9b482f8ab1ac7e51338937406111ea7ce342512.tar.gz
websocket: fix mis-handshake in the case of lack of HTTP origin header
Fixes golang/go#10102. Change-Id: I34779a81797cb3b7e8820f5af8b0dde54f949164 Reviewed-on: https://go-review.googlesource.com/7034 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
-rw-r--r--websocket/hybi.go6
-rw-r--r--websocket/websocket_test.go37
2 files changed, 40 insertions, 3 deletions
diff --git a/websocket/hybi.go b/websocket/hybi.go
index f8c0b2e..41c854e 100644
--- a/websocket/hybi.go
+++ b/websocket/hybi.go
@@ -515,15 +515,15 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques
return http.StatusSwitchingProtocols, nil
}
-// Origin parses Origin header in "req".
-// If origin is "null", returns (nil, nil).
+// Origin parses the Origin header in req.
+// If the Origin header is not set, it returns nil and nil.
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
- if origin == "null" {
+ if origin == "" {
return nil, nil
}
return url.ParseRequestURI(origin)
diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go
index 725a79f..afc3eed 100644
--- a/websocket/websocket_test.go
+++ b/websocket/websocket_test.go
@@ -13,6 +13,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "reflect"
"strings"
"sync"
"testing"
@@ -450,3 +451,39 @@ func TestClose(t *testing.T) {
t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
}
}
+
+var originTests = []struct {
+ req *http.Request
+ origin *url.URL
+}{
+ {
+ req: &http.Request{
+ Header: http.Header{
+ "Origin": []string{"http://www.example.com"},
+ },
+ },
+ origin: &url.URL{
+ Scheme: "http",
+ Host: "www.example.com",
+ },
+ },
+ {
+ req: &http.Request{},
+ },
+}
+
+func TestOrigin(t *testing.T) {
+ conf := newConfig(t, "/echo")
+ conf.Version = ProtocolVersionHybi13
+ for i, tt := range originTests {
+ origin, err := Origin(conf, tt.req)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if !reflect.DeepEqual(origin, tt.origin) {
+ t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
+ continue
+ }
+ }
+}