@@ -62,20 +62,47 @@ func TestAccept(t *testing.T) {
62
62
t .Run ("badCompression" , func (t * testing.T ) {
63
63
t .Parallel ()
64
64
65
- w := mockHijacker {
66
- ResponseWriter : httptest .NewRecorder (),
65
+ newRequest := func (extensions string ) * http.Request {
66
+ r := httptest .NewRequest ("GET" , "/" , nil )
67
+ r .Header .Set ("Connection" , "Upgrade" )
68
+ r .Header .Set ("Upgrade" , "websocket" )
69
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
70
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
71
+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
72
+ return r
73
+ }
74
+ newResponseWriter := func () http.ResponseWriter {
75
+ return mockHijacker {
76
+ ResponseWriter : httptest .NewRecorder (),
77
+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
78
+ return nil , nil , errors .New ("hijack error" )
79
+ },
80
+ }
67
81
}
68
- r := httptest .NewRequest ("GET" , "/" , nil )
69
- r .Header .Set ("Connection" , "Upgrade" )
70
- r .Header .Set ("Upgrade" , "websocket" )
71
- r .Header .Set ("Sec-WebSocket-Version" , "13" )
72
- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
73
- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
74
82
75
- _ , err := Accept (w , r , & AcceptOptions {
76
- CompressionMode : CompressionContextTakeover ,
83
+ t .Run ("withoutFallback" , func (t * testing.T ) {
84
+ t .Parallel ()
85
+
86
+ w := newResponseWriter ()
87
+ r := newRequest ("permessage-deflate; harharhar" )
88
+ _ , _ = Accept (w , r , & AcceptOptions {
89
+ CompressionMode : CompressionNoContextTakeover ,
90
+ })
91
+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
92
+ })
93
+ t .Run ("withFallback" , func (t * testing.T ) {
94
+ t .Parallel ()
95
+
96
+ w := newResponseWriter ()
97
+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
98
+ _ , _ = Accept (w , r , & AcceptOptions {
99
+ CompressionMode : CompressionNoContextTakeover ,
100
+ })
101
+ assert .Equal (t , "extension header" ,
102
+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
103
+ CompressionNoContextTakeover .opts ().String (),
104
+ )
77
105
})
78
- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
79
106
})
80
107
81
108
t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -344,42 +371,53 @@ func Test_authenticateOrigin(t *testing.T) {
344
371
}
345
372
}
346
373
347
- func Test_acceptCompression (t * testing.T ) {
374
+ func Test_selectDeflate (t * testing.T ) {
348
375
t .Parallel ()
349
376
350
377
testCases := []struct {
351
- name string
352
- mode CompressionMode
353
- reqSecWebSocketExtensions string
354
- respSecWebSocketExtensions string
355
- expCopts * compressionOptions
356
- error bool
378
+ name string
379
+ mode CompressionMode
380
+ header string
381
+ expCopts * compressionOptions
382
+ expOK bool
357
383
}{
358
384
{
359
385
name : "disabled" ,
360
386
mode : CompressionDisabled ,
361
387
expCopts : nil ,
388
+ expOK : false ,
362
389
},
363
390
{
364
391
name : "noClientSupport" ,
365
392
mode : CompressionNoContextTakeover ,
366
393
expCopts : nil ,
394
+ expOK : false ,
367
395
},
368
396
{
369
- name : "permessage-deflate" ,
370
- mode : CompressionNoContextTakeover ,
371
- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
372
- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
397
+ name : "permessage-deflate" ,
398
+ mode : CompressionNoContextTakeover ,
399
+ header : "permessage-deflate; client_max_window_bits" ,
373
400
expCopts : & compressionOptions {
374
401
clientNoContextTakeover : true ,
375
402
serverNoContextTakeover : true ,
376
403
},
404
+ expOK : true ,
405
+ },
406
+ {
407
+ name : "permessage-deflate/unknown-parameter" ,
408
+ mode : CompressionNoContextTakeover ,
409
+ header : "permessage-deflate; meow" ,
410
+ expOK : false ,
377
411
},
378
412
{
379
- name : "permessage-deflate/error" ,
380
- mode : CompressionNoContextTakeover ,
381
- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
382
- error : true ,
413
+ name : "permessage-deflate/unknown-parameter" ,
414
+ mode : CompressionNoContextTakeover ,
415
+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
416
+ expCopts : & compressionOptions {
417
+ clientNoContextTakeover : true ,
418
+ serverNoContextTakeover : true ,
419
+ },
420
+ expOK : true ,
383
421
},
384
422
// {
385
423
// name: "x-webkit-deflate-frame",
@@ -404,19 +442,11 @@ func Test_acceptCompression(t *testing.T) {
404
442
t .Run (tc .name , func (t * testing.T ) {
405
443
t .Parallel ()
406
444
407
- r := httptest .NewRequest (http .MethodGet , "/" , nil )
408
- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
409
-
410
- w := httptest .NewRecorder ()
411
- copts , err := acceptCompression (r , w , tc .mode )
412
- if tc .error {
413
- assert .Error (t , err )
414
- return
415
- }
416
-
417
- assert .Success (t , err )
445
+ h := http.Header {}
446
+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
447
+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
448
+ assert .Equal (t , "selected options" , tc .expOK , ok )
418
449
assert .Equal (t , "compression options" , tc .expCopts , copts )
419
- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
420
450
})
421
451
}
422
452
}
0 commit comments