From 500b9d734508a2c8ab09457ac9895b895bc86470 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 25 Feb 2020 22:20:19 -0500 Subject: [PATCH] Add OriginPatterns to AcceptOptions Closes #194 --- accept.go | 73 +++++++++++++++++++++++++++++++------------------ accept_test.go | 31 +++++++++++++++++---- example_test.go | 12 +------- 3 files changed, 74 insertions(+), 42 deletions(-) diff --git a/accept.go b/accept.go index 479138fc..47e20b52 100644 --- a/accept.go +++ b/accept.go @@ -9,10 +9,11 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/textproto" "net/url" - "strconv" + "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" @@ -25,18 +26,27 @@ type AcceptOptions struct { // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string - // InsecureSkipVerify disables Accept's origin verification behaviour. By default, - // the connection will only be accepted if the request origin is equal to the request - // host. + // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // - // This is only required if you want javascript served from a different domain - // to access your WebSocket server. + // Deprecated: Use OriginPatterns with a match all pattern of * instead to control + // origin authorization yourself. + InsecureSkipVerify bool + + // OriginPatterns lists the host patterns for authorized origins. + // The request host is always authorized. + // Use this to enable cross origin WebSockets. + // + // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. + // In such a case, example.com is the origin and chat.example.com is the request host. + // One would set this field to []string{"example.com"} to authorize example.com to connect. // - // See https://stackoverflow.com/a/37837709/4283659 + // Each pattern is matched case insensitively against the request origin host + // with filepath.Match. + // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. - InsecureSkipVerify bool + OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. @@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) + err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } @@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return 0, nil } -func authenticateOrigin(r *http.Request) error { +func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) + if origin == "" { + return nil + } + + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + + if strings.EqualFold(r.Host, u.Host) { + return nil + } + + for _, hostPattern := range originHosts { + matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if matched { + return nil } } - return nil + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) +} + +func match(pattern, s string) (bool, error) { + return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { @@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi return copts, nil } -// parseExtensionParameter parses the value in the extension parameter p. -func parseExtensionParameter(p string) (int, bool) { - ps := strings.Split(p, "=") - if len(ps) == 1 { - return 0, false - } - i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) - return i, e == nil -} - func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. diff --git a/accept_test.go b/accept_test.go index 49667799..40a7b40c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - host string - success bool + name string + origin string + host string + originPatterns []string + success bool }{ { name: "none", @@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) { host: "example.com", success: true, }, + { + name: "originPatterns", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "*.example.com", + "bar.com", + }, + success: true, + }, + { + name: "originPatternsUnauthorized", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "exam3.com", + "bar.com", + }, + success: false, + }, } for _, tc := range testCases { @@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) { r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r) + err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { diff --git a/example_test.go b/example_test.go index 666914d2..c56e53f3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,7 +6,6 @@ import ( "context" "log" "net/http" - "net/url" "time" "nhooyr.io/websocket" @@ -121,17 +120,8 @@ func Example_writeOnly() { // from the origin example.com. func Example_crossOrigin() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) - if err != nil || u.Host != "example.com" { - http.Error(w, "bad origin header", http.StatusForbidden) - return - } - } - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, + OriginPatterns: []string{"example.com"}, }) if err != nil { log.Println(err)