Skip to content

Commit 500b9d7

Browse files
committed
Add OriginPatterns to AcceptOptions
Closes #194
1 parent fa720b9 commit 500b9d7

File tree

3 files changed

+74
-42
lines changed

3 files changed

+74
-42
lines changed

Diff for: accept.go

+47-26
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ import (
99
"errors"
1010
"fmt"
1111
"io"
12+
"log"
1213
"net/http"
1314
"net/textproto"
1415
"net/url"
15-
"strconv"
16+
"path/filepath"
1617
"strings"
1718

1819
"nhooyr.io/websocket/internal/errd"
@@ -25,18 +26,27 @@ type AcceptOptions struct {
2526
// reject it, close the connection when c.Subprotocol() == "".
2627
Subprotocols []string
2728

28-
// InsecureSkipVerify disables Accept's origin verification behaviour. By default,
29-
// the connection will only be accepted if the request origin is equal to the request
30-
// host.
29+
// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
3130
//
32-
// This is only required if you want javascript served from a different domain
33-
// to access your WebSocket server.
31+
// Deprecated: Use OriginPatterns with a match all pattern of * instead to control
32+
// origin authorization yourself.
33+
InsecureSkipVerify bool
34+
35+
// OriginPatterns lists the host patterns for authorized origins.
36+
// The request host is always authorized.
37+
// Use this to enable cross origin WebSockets.
38+
//
39+
// i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
40+
// In such a case, example.com is the origin and chat.example.com is the request host.
41+
// One would set this field to []string{"example.com"} to authorize example.com to connect.
3442
//
35-
// See https://stackoverflow.com/a/37837709/4283659
43+
// Each pattern is matched case insensitively against the request origin host
44+
// with filepath.Match.
45+
// See https://golang.org/pkg/path/filepath/#Match
3646
//
3747
// Please ensure you understand the ramifications of enabling this.
3848
// If used incorrectly your WebSocket server will be open to CSRF attacks.
39-
InsecureSkipVerify bool
49+
OriginPatterns []string
4050

4151
// CompressionMode controls the compression mode.
4252
// Defaults to CompressionNoContextTakeover.
@@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
7787
}
7888

7989
if !opts.InsecureSkipVerify {
80-
err = authenticateOrigin(r)
90+
err = authenticateOrigin(r, opts.OriginPatterns)
8191
if err != nil {
92+
if errors.Is(err, filepath.ErrBadPattern) {
93+
log.Printf("websocket: %v", err)
94+
err = errors.New(http.StatusText(http.StatusForbidden))
95+
}
8296
http.Error(w, err.Error(), http.StatusForbidden)
8397
return nil, err
8498
}
@@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
165179
return 0, nil
166180
}
167181

168-
func authenticateOrigin(r *http.Request) error {
182+
func authenticateOrigin(r *http.Request, originHosts []string) error {
169183
origin := r.Header.Get("Origin")
170-
if origin != "" {
171-
u, err := url.Parse(origin)
184+
if origin == "" {
185+
return nil
186+
}
187+
188+
u, err := url.Parse(origin)
189+
if err != nil {
190+
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
191+
}
192+
193+
if strings.EqualFold(r.Host, u.Host) {
194+
return nil
195+
}
196+
197+
for _, hostPattern := range originHosts {
198+
matched, err := match(hostPattern, u.Host)
172199
if err != nil {
173-
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
200+
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
174201
}
175-
if !strings.EqualFold(u.Host, r.Host) {
176-
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
202+
if matched {
203+
return nil
177204
}
178205
}
179-
return nil
206+
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
207+
}
208+
209+
func match(pattern, s string) (bool, error) {
210+
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
180211
}
181212

182213
func selectSubprotocol(r *http.Request, subprotocols []string) string {
@@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
235266
return copts, nil
236267
}
237268

238-
// parseExtensionParameter parses the value in the extension parameter p.
239-
func parseExtensionParameter(p string) (int, bool) {
240-
ps := strings.Split(p, "=")
241-
if len(ps) == 1 {
242-
return 0, false
243-
}
244-
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
245-
return i, e == nil
246-
}
247-
248269
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
249270
copts := mode.opts()
250271
// The peer must explicitly request it.

Diff for: accept_test.go

+26-5
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) {
244244
t.Parallel()
245245

246246
testCases := []struct {
247-
name string
248-
origin string
249-
host string
250-
success bool
247+
name string
248+
origin string
249+
host string
250+
originPatterns []string
251+
success bool
251252
}{
252253
{
253254
name: "none",
@@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) {
278279
host: "example.com",
279280
success: true,
280281
},
282+
{
283+
name: "originPatterns",
284+
origin: "https://two.examplE.com",
285+
host: "example.com",
286+
originPatterns: []string{
287+
"*.example.com",
288+
"bar.com",
289+
},
290+
success: true,
291+
},
292+
{
293+
name: "originPatternsUnauthorized",
294+
origin: "https://two.examplE.com",
295+
host: "example.com",
296+
originPatterns: []string{
297+
"exam3.com",
298+
"bar.com",
299+
},
300+
success: false,
301+
},
281302
}
282303

283304
for _, tc := range testCases {
@@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) {
288309
r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
289310
r.Header.Set("Origin", tc.origin)
290311

291-
err := authenticateOrigin(r)
312+
err := authenticateOrigin(r, tc.originPatterns)
292313
if tc.success {
293314
assert.Success(t, err)
294315
} else {

Diff for: example_test.go

+1-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"context"
77
"log"
88
"net/http"
9-
"net/url"
109
"time"
1110

1211
"nhooyr.io/websocket"
@@ -121,17 +120,8 @@ func Example_writeOnly() {
121120
// from the origin example.com.
122121
func Example_crossOrigin() {
123122
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124-
origin := r.Header.Get("Origin")
125-
if origin != "" {
126-
u, err := url.Parse(origin)
127-
if err != nil || u.Host != "example.com" {
128-
http.Error(w, "bad origin header", http.StatusForbidden)
129-
return
130-
}
131-
}
132-
133123
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
134-
InsecureSkipVerify: true,
124+
OriginPatterns: []string{"example.com"},
135125
})
136126
if err != nil {
137127
log.Println(err)

0 commit comments

Comments
 (0)