Skip to content

Commit b3b0226

Browse files
committed
add more tests
1 parent 1bb981e commit b3b0226

File tree

1 file changed

+256
-0
lines changed

1 file changed

+256
-0
lines changed

service/internal/server/server_test.go

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package server
22

33
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"strings"
47
"testing"
58

9+
"github.com/go-chi/cors"
610
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
712
)
813

914
func TestMergeStringSlices(t *testing.T) {
@@ -266,3 +271,254 @@ func TestCORSConfig_EffectiveExposedHeaders(t *testing.T) {
266271
})
267272
}
268273
}
274+
275+
// TestCORSMiddleware_WildcardOrigin tests that the CORS middleware correctly handles
276+
// wildcard origin configuration with credentials enabled. Per CORS spec, when
277+
// credentials are allowed, the response must reflect the actual origin, not "*".
278+
func TestCORSMiddleware_WildcardOrigin(t *testing.T) {
279+
// Configure CORS the same way as newHTTPServer does
280+
cfg := CORSConfig{
281+
Enabled: true,
282+
AllowedOrigins: []string{"*"},
283+
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
284+
AllowedHeaders: []string{"Authorization", "Content-Type"},
285+
ExposedHeaders: []string{"Link"},
286+
AllowCredentials: true,
287+
MaxAge: 3600,
288+
}
289+
290+
// Create the CORS handler using the same pattern as server.go
291+
corsHandler := cors.New(cors.Options{
292+
AllowOriginFunc: func(_ *http.Request, origin string) bool {
293+
for _, allowedOrigin := range cfg.AllowedOrigins {
294+
if allowedOrigin == "*" {
295+
return true
296+
}
297+
if strings.EqualFold(origin, allowedOrigin) {
298+
return true
299+
}
300+
}
301+
return false
302+
},
303+
AllowedMethods: cfg.EffectiveMethods(),
304+
AllowedHeaders: cfg.EffectiveHeaders(),
305+
ExposedHeaders: cfg.EffectiveExposedHeaders(),
306+
AllowCredentials: cfg.AllowCredentials,
307+
MaxAge: cfg.MaxAge,
308+
})
309+
310+
// Create a simple handler wrapped with CORS
311+
handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
312+
w.WriteHeader(http.StatusOK)
313+
}))
314+
315+
tests := []struct {
316+
name string
317+
origin string
318+
method string
319+
requestHeaders string
320+
wantOrigin string
321+
wantCreds string
322+
}{
323+
{
324+
name: "preflight from localhost:3000",
325+
origin: "http://localhost:3000",
326+
method: http.MethodOptions,
327+
requestHeaders: "authorization,content-type",
328+
wantOrigin: "http://localhost:3000",
329+
wantCreds: "true",
330+
},
331+
{
332+
name: "preflight from example.com",
333+
origin: "https://example.com",
334+
method: http.MethodOptions,
335+
requestHeaders: "authorization",
336+
wantOrigin: "https://example.com",
337+
wantCreds: "true",
338+
},
339+
{
340+
name: "preflight from arbitrary origin",
341+
origin: "https://any-site.io",
342+
method: http.MethodOptions,
343+
requestHeaders: "content-type",
344+
wantOrigin: "https://any-site.io",
345+
wantCreds: "true",
346+
},
347+
{
348+
name: "actual request from localhost",
349+
origin: "http://localhost:3000",
350+
method: http.MethodGet,
351+
wantOrigin: "http://localhost:3000",
352+
wantCreds: "true",
353+
},
354+
}
355+
356+
for _, tt := range tests {
357+
t.Run(tt.name, func(t *testing.T) {
358+
req := httptest.NewRequest(tt.method, "/test", nil)
359+
req.Header.Set("Origin", tt.origin)
360+
if tt.method == http.MethodOptions {
361+
req.Header.Set("Access-Control-Request-Method", "POST")
362+
if tt.requestHeaders != "" {
363+
req.Header.Set("Access-Control-Request-Headers", tt.requestHeaders)
364+
}
365+
}
366+
367+
rr := httptest.NewRecorder()
368+
handler.ServeHTTP(rr, req)
369+
370+
// Verify origin is reflected back (not "*" since credentials are enabled)
371+
gotOrigin := rr.Header().Get("Access-Control-Allow-Origin")
372+
require.Equal(t, tt.wantOrigin, gotOrigin,
373+
"Origin should be reflected back, not '*', when credentials are enabled")
374+
375+
// Verify credentials header
376+
gotCreds := rr.Header().Get("Access-Control-Allow-Credentials")
377+
require.Equal(t, tt.wantCreds, gotCreds)
378+
379+
// For preflight, verify allowed headers
380+
if tt.method == http.MethodOptions {
381+
gotHeaders := rr.Header().Get("Access-Control-Allow-Headers")
382+
require.NotEmpty(t, gotHeaders, "Preflight should include allowed headers")
383+
}
384+
})
385+
}
386+
}
387+
388+
// TestCORSMiddleware_WildcardWithSpecificOrigins tests that wildcard takes precedence
389+
// when mixed with specific origins - all origins are allowed if "*" is in the list.
390+
func TestCORSMiddleware_WildcardWithSpecificOrigins(t *testing.T) {
391+
cfg := CORSConfig{
392+
AllowedOrigins: []string{"https://specific.com", "*", "https://another.com"},
393+
AllowedMethods: []string{"GET", "POST"},
394+
AllowedHeaders: []string{"Authorization"},
395+
AllowCredentials: true,
396+
}
397+
398+
corsHandler := cors.New(cors.Options{
399+
AllowOriginFunc: func(_ *http.Request, origin string) bool {
400+
for _, allowedOrigin := range cfg.AllowedOrigins {
401+
if allowedOrigin == "*" {
402+
return true
403+
}
404+
if strings.EqualFold(origin, allowedOrigin) {
405+
return true
406+
}
407+
}
408+
return false
409+
},
410+
AllowedMethods: cfg.AllowedMethods,
411+
AllowedHeaders: cfg.AllowedHeaders,
412+
AllowCredentials: cfg.AllowCredentials,
413+
})
414+
415+
handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
416+
w.WriteHeader(http.StatusOK)
417+
}))
418+
419+
// When "*" is in the list, ANY origin should be allowed
420+
tests := []struct {
421+
name string
422+
origin string
423+
wantOrigin string
424+
}{
425+
{
426+
name: "specific origin still works",
427+
origin: "https://specific.com",
428+
wantOrigin: "https://specific.com",
429+
},
430+
{
431+
name: "random origin allowed due to wildcard",
432+
origin: "https://random-site.io",
433+
wantOrigin: "https://random-site.io",
434+
},
435+
{
436+
name: "evil origin also allowed due to wildcard",
437+
origin: "https://evil.com",
438+
wantOrigin: "https://evil.com",
439+
},
440+
}
441+
442+
for _, tt := range tests {
443+
t.Run(tt.name, func(t *testing.T) {
444+
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
445+
req.Header.Set("Origin", tt.origin)
446+
req.Header.Set("Access-Control-Request-Method", "GET")
447+
448+
rr := httptest.NewRecorder()
449+
handler.ServeHTTP(rr, req)
450+
451+
gotOrigin := rr.Header().Get("Access-Control-Allow-Origin")
452+
assert.Equal(t, tt.wantOrigin, gotOrigin,
453+
"Wildcard in list should allow ALL origins")
454+
})
455+
}
456+
}
457+
458+
// TestCORSMiddleware_SpecificOrigins tests CORS with specific origin list (not wildcard)
459+
func TestCORSMiddleware_SpecificOrigins(t *testing.T) {
460+
cfg := CORSConfig{
461+
Enabled: true,
462+
AllowedOrigins: []string{"https://allowed.com", "https://also-allowed.com"},
463+
AllowedMethods: []string{"GET", "POST"},
464+
AllowedHeaders: []string{"Authorization"},
465+
AllowCredentials: true,
466+
}
467+
468+
corsHandler := cors.New(cors.Options{
469+
AllowOriginFunc: func(_ *http.Request, origin string) bool {
470+
for _, allowedOrigin := range cfg.AllowedOrigins {
471+
if allowedOrigin == "*" {
472+
return true
473+
}
474+
if strings.EqualFold(origin, allowedOrigin) {
475+
return true
476+
}
477+
}
478+
return false
479+
},
480+
AllowedMethods: cfg.AllowedMethods,
481+
AllowedHeaders: cfg.AllowedHeaders,
482+
AllowCredentials: cfg.AllowCredentials,
483+
})
484+
485+
handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
486+
w.WriteHeader(http.StatusOK)
487+
}))
488+
489+
tests := []struct {
490+
name string
491+
origin string
492+
wantOrigin string
493+
}{
494+
{
495+
name: "allowed origin",
496+
origin: "https://allowed.com",
497+
wantOrigin: "https://allowed.com",
498+
},
499+
{
500+
name: "also allowed origin",
501+
origin: "https://also-allowed.com",
502+
wantOrigin: "https://also-allowed.com",
503+
},
504+
{
505+
name: "disallowed origin - no CORS headers",
506+
origin: "https://evil.com",
507+
wantOrigin: "", // No Access-Control-Allow-Origin header
508+
},
509+
}
510+
511+
for _, tt := range tests {
512+
t.Run(tt.name, func(t *testing.T) {
513+
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
514+
req.Header.Set("Origin", tt.origin)
515+
req.Header.Set("Access-Control-Request-Method", "GET")
516+
517+
rr := httptest.NewRecorder()
518+
handler.ServeHTTP(rr, req)
519+
520+
gotOrigin := rr.Header().Get("Access-Control-Allow-Origin")
521+
assert.Equal(t, tt.wantOrigin, gotOrigin)
522+
})
523+
}
524+
}

0 commit comments

Comments
 (0)