From 872897ff964df88995410cf2e7f9249439cf7461 Mon Sep 17 00:00:00 2001 From: Riccardo Piccoli Date: Tue, 30 Mar 2021 22:14:59 +0200 Subject: [PATCH] fix: allow mountPaths with traling slash (#5521) Signed-off-by: Riccardo Piccoli --- workflow/common/util.go | 24 ++++++++++++++---------- workflow/common/util_test.go | 28 ++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/workflow/common/util.go b/workflow/common/util.go index d72e4b7171b0..0cde8ce51ad0 100644 --- a/workflow/common/util.go +++ b/workflow/common/util.go @@ -8,6 +8,7 @@ import ( "net/http" "os/exec" "runtime" + "sort" "strings" "time" @@ -34,18 +35,21 @@ import ( // user specified volumeMounts in the template, and returns the deepest volumeMount // (if any). A return value of nil indicates the path is not under any volumeMount. func FindOverlappingVolume(tmpl *wfv1.Template, path string) *apiv1.VolumeMount { - var volMnt *apiv1.VolumeMount - deepestLen := 0 - for _, mnt := range tmpl.GetVolumeMounts() { - if path != mnt.MountPath && !strings.HasPrefix(path, mnt.MountPath+"/") { - continue - } - if len(mnt.MountPath) > deepestLen { - volMnt = &mnt - deepestLen = len(mnt.MountPath) + volumeMounts := tmpl.GetVolumeMounts() + sort.Slice(volumeMounts, func(i, j int) bool { + return len(volumeMounts[i].MountPath) > len(volumeMounts[j].MountPath) + }) + for _, mnt := range volumeMounts { + normalizedMountPath := strings.TrimRight(mnt.MountPath, "/") + if path == normalizedMountPath || isSubPath(path, normalizedMountPath) { + return &mnt } } - return volMnt + return nil +} + +func isSubPath(path string, normalizedMountPath string) bool { + return strings.HasPrefix(path, normalizedMountPath+"/") } type RoundTripCallback func(conn *websocket.Conn, resp *http.Response, err error) error diff --git a/workflow/common/util_test.go b/workflow/common/util_test.go index 433690674dd3..97a4b0c3b881 100644 --- a/workflow/common/util_test.go +++ b/workflow/common/util_test.go @@ -57,14 +57,38 @@ func TestFindOverlappingVolume(t *testing.T) { Name: "workdir", MountPath: "/user-mount", } + volMntTrailing := corev1.VolumeMount{ + Name: "aux", + MountPath: "/trailing-slash/", + } templateWithVolMount := &wfv1.Template{ Container: &corev1.Container{ - VolumeMounts: []corev1.VolumeMount{volMnt}, + VolumeMounts: []corev1.VolumeMount{volMnt, volMntTrailing}, + }, + } + + deeperVolMnt := corev1.VolumeMount{ + Name: "workdir", + MountPath: "/user-mount/deeper", + } + + templateWithDeeperVolMount := &wfv1.Template{ + Container: &corev1.Container{ + VolumeMounts: []corev1.VolumeMount{volMnt, deeperVolMnt}, }, } + assert.Equal(t, &volMnt, FindOverlappingVolume(templateWithVolMount, "/user-mount")) assert.Equal(t, &volMnt, FindOverlappingVolume(templateWithVolMount, "/user-mount/subdir")) - assert.Nil(t, FindOverlappingVolume(templateWithVolMount, "/user-mount-coincidental-prefix")) + assert.Equal(t, &volMnt, FindOverlappingVolume(templateWithVolMount, "/user-mount/")) + + assert.Equal(t, &deeperVolMnt, FindOverlappingVolume(templateWithDeeperVolMount, "/user-mount/deeper")) + assert.Equal(t, &deeperVolMnt, FindOverlappingVolume(templateWithDeeperVolMount, "/user-mount/deeper/with-subdir")) + + assert.Equal(t, &volMntTrailing, FindOverlappingVolume(templateWithVolMount, "/trailing-slash/")) + assert.Equal(t, &volMntTrailing, FindOverlappingVolume(templateWithVolMount, "/trailing-slash/with-subpath")) + + assert.Nil(t, FindOverlappingVolume(templateWithVolMount, "/user-mount-coincidental-prefix/")) } func TestUnknownFieldEnforcerForWorkflowStep(t *testing.T) {