diff --git a/pkg/util/net/http.go b/pkg/util/net/http.go index 77488388..41be5b2f 100644 --- a/pkg/util/net/http.go +++ b/pkg/util/net/http.go @@ -26,6 +26,7 @@ import ( "net/http" "net/url" "os" + "path" "strconv" "strings" @@ -33,6 +34,26 @@ import ( "golang.org/x/net/http2" ) +// JoinPreservingTrailingSlash does a path.Join of the specified elements, +// preserving any trailing slash on the last non-empty segment +func JoinPreservingTrailingSlash(elem ...string) string { + // do the basic path join + result := path.Join(elem...) + + // find the last non-empty segment + for i := len(elem) - 1; i >= 0; i-- { + if len(elem[i]) > 0 { + // if the last segment ended in a slash, ensure our result does as well + if strings.HasSuffix(elem[i], "/") && !strings.HasSuffix(result, "/") { + result += "/" + } + break + } + } + + return result +} + // IsProbableEOF returns true if the given error resembles a connection termination // scenario that would justify assuming that the watch is empty. // These errors are what the Go http stack returns back to us which are general diff --git a/pkg/util/net/http_test.go b/pkg/util/net/http_test.go index 2d41eda4..8f5dd9cd 100644 --- a/pkg/util/net/http_test.go +++ b/pkg/util/net/http_test.go @@ -20,6 +20,7 @@ package net import ( "crypto/tls" + "fmt" "net" "net/http" "net/url" @@ -218,3 +219,40 @@ func TestTLSClientConfigHolder(t *testing.T) { t.Errorf("didn't find tls config") } } + +func TestJoinPreservingTrailingSlash(t *testing.T) { + tests := []struct { + a string + b string + want string + }{ + // All empty + {"", "", ""}, + + // Empty a + {"", "/", "/"}, + {"", "foo", "foo"}, + {"", "/foo", "/foo"}, + {"", "/foo/", "/foo/"}, + + // Empty b + {"/", "", "/"}, + {"foo", "", "foo"}, + {"/foo", "", "/foo"}, + {"/foo/", "", "/foo/"}, + + // Both populated + {"/", "/", "/"}, + {"foo", "foo", "foo/foo"}, + {"/foo", "/foo", "/foo/foo"}, + {"/foo/", "/foo/", "/foo/foo/"}, + } + for _, tt := range tests { + name := fmt.Sprintf("%q+%q=%q", tt.a, tt.b, tt.want) + t.Run(name, func(t *testing.T) { + if got := JoinPreservingTrailingSlash(tt.a, tt.b); got != tt.want { + t.Errorf("JoinPreservingTrailingSlash() = %v, want %v", got, tt.want) + } + }) + } +}