diff --git a/join.go b/join.go index aa32b85..3080950 100644 --- a/join.go +++ b/join.go @@ -11,7 +11,6 @@ package securejoin import ( - "bytes" "errors" "os" "path/filepath" @@ -19,6 +18,8 @@ import ( "syscall" ) +const maxSymlinkLimit = 255 + // IsNotExist tells you if err is an error that implies that either the path // accessed does not exist (or path components don't exist). This is // effectively a more broad version of os.IsNotExist. @@ -51,71 +52,68 @@ func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { } unsafePath = filepath.FromSlash(unsafePath) - var path bytes.Buffer - n := 0 + var ( + currentPath string + linksWalked int + ) for unsafePath != "" { - if n > 255 { - return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP} - } - if v := filepath.VolumeName(unsafePath); v != "" { unsafePath = unsafePath[len(v):] } - // Next path component, p. - i := strings.IndexRune(unsafePath, filepath.Separator) - var p string - if i == -1 { - p, unsafePath = unsafePath, "" + // Get the next path component. + var part string + if i := strings.IndexRune(unsafePath, filepath.Separator); i == -1 { + part, unsafePath = unsafePath, "" } else { - p, unsafePath = unsafePath[:i], unsafePath[i+1:] + part, unsafePath = unsafePath[:i], unsafePath[i+1:] } - // Create a cleaned path, using the lexical semantics of /../a, to - // create a "scoped" path component which can safely be joined to fullP - // for evaluation. At this point, path.String() doesn't contain any - // symlink components. - cleanP := filepath.Clean(string(filepath.Separator) + path.String() + p) - if cleanP == string(filepath.Separator) { - path.Reset() + // Apply the component lexically to the path we are building. + // currentPath does not contain any symlinks, and we are lexically + // dealing with a single component, so it's okay to do a filepath.Clean + // here. + nextPath := filepath.Join(string(filepath.Separator), currentPath, part) + if nextPath == string(filepath.Separator) { + currentPath = "" continue } - fullP := filepath.Clean(root + cleanP) + fullPath := root + string(filepath.Separator) + nextPath // Figure out whether the path is a symlink. - fi, err := vfs.Lstat(fullP) + fi, err := vfs.Lstat(fullPath) if err != nil && !IsNotExist(err) { return "", err } // Treat non-existent path components the same as non-symlinks (we // can't do any better here). if IsNotExist(err) || fi.Mode()&os.ModeSymlink == 0 { - path.WriteString(p) - path.WriteRune(filepath.Separator) + currentPath = nextPath continue } - // Only increment when we actually dereference a link. - n++ + // It's a symlink, so get its contents and expand it by prepending it + // to the yet-unparsed path. + linksWalked++ + if linksWalked > maxSymlinkLimit { + return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP} + } - // It's a symlink, expand it by prepending it to the yet-unparsed path. - dest, err := vfs.Readlink(fullP) + dest, err := vfs.Readlink(fullPath) if err != nil { return "", err } + unsafePath = dest + string(filepath.Separator) + unsafePath // Absolute symlinks reset any work we've already done. if filepath.IsAbs(dest) { - path.Reset() + currentPath = "" } - unsafePath = dest + string(filepath.Separator) + unsafePath } - // We have to clean path.String() here because it may contain '..' - // components that are entirely lexical, but would be misleading otherwise. - // And finally do a final clean to ensure that root is also lexically - // clean. - fullP := filepath.Clean(string(filepath.Separator) + path.String()) - return filepath.Clean(root + fullP), nil + // There should be no lexical components like ".." left in the path here, + // but for safety clean up the path before joining it to the root. + finalPath := filepath.Join(string(filepath.Separator), currentPath) + return filepath.Join(root, finalPath), nil } // SecureJoin is a wrapper around SecureJoinVFS that just uses the os.* library