From 3fb048af5c63bc9545de61ea7691ac13492a1fb3 Mon Sep 17 00:00:00 2001 From: Kevin Parsons Date: Sun, 25 Sep 2022 11:02:48 -0700 Subject: [PATCH] pkg/fs: Add RemoveAll function New RemoveAll function intended to be a replacement for os.RemoveAll usage and be tailored to Windows scenarios. Most importantly, this function works with []uint16 instead of string. This allows it to properly delete Windows filesystem trees that contain files with invalid UTF-16 paths. Signed-off-by: Kevin Parsons --- pkg/fs/removeall_windows.go | 214 +++++++++++++++++++++++++++++ pkg/fs/removeall_windows_test.go | 222 +++++++++++++++++++++++++++++++ pkg/fs/syscall_windows.go | 3 + pkg/fs/zsyscall_windows.go | 64 +++++++++ syscall.go => syscall_windows.go | 2 - 5 files changed, 503 insertions(+), 2 deletions(-) create mode 100644 pkg/fs/removeall_windows.go create mode 100644 pkg/fs/removeall_windows_test.go create mode 100644 pkg/fs/syscall_windows.go create mode 100644 pkg/fs/zsyscall_windows.go rename syscall.go => syscall_windows.go (85%) diff --git a/pkg/fs/removeall_windows.go b/pkg/fs/removeall_windows.go new file mode 100644 index 00000000..250cba53 --- /dev/null +++ b/pkg/fs/removeall_windows.go @@ -0,0 +1,214 @@ +package fs + +import ( + "errors" + "io/fs" + "unsafe" + + "golang.org/x/sys/windows" +) + +// We want to use findFirstFileExW since it allows us to avoid the shortname lookup. +// Package golang.org/x/sys/windows only defines FindFirstFile, so we define our own here. +// We could use its findNextFileW definition, except that it also has an inaccurate version of Win32finddata[1]. +// Therefore we define our own versions of both, as well as our own findData struct. +// +// [1] This inaccurate version is still usable, as the name and alternate name fields are one character short, +// but the original value is guaranteed to be null terminated, so you just lose the null terminator in the worst case. 
+// However, this is still annoying to deal with, as you either must find the null in the array, or treat the full array
+// as a null terminated string.
+
+//sys findFirstFileExW(pattern *uint16, infoLevel uint32, data *findData, searchOp uint32, searchFilter unsafe.Pointer, flags uint32) (h windows.Handle, err error) [failretval==windows.InvalidHandle] = kernel32.FindFirstFileExW
+//sys findNextFileW(findHandle windows.Handle, data *findData) (err error) = kernel32.FindNextFileW
+
+const (
+	// Part of FINDEX_INFO_LEVELS. removeAll passes it as the infoLevel argument of findFirstFileExW.
+	//
+	// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ne-minwinbase-findex_info_levels
+	findExInfoBasic = 1
+
+	// Part of FINDEX_SEARCH_OPS. removeAll passes it as the searchOp argument of findFirstFileExW.
+	//
+	// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ne-minwinbase-findex_search_ops
+	findExSearchNameMatch = 0
+
+	// Indicates the reparse point is a name surrogate, which means it refers to another filesystem entry.
+	// This is used for things like symlinks and junction (mount) points; removeAll never recurses into such entries.
+	//
+	// https://learn.microsoft.com/en-us/windows/win32/fileio/reparse-point-tags#tag-contents
+	reparseTagNameSurrogate = 0x20000000
+)
+
+// findData represents Windows's WIN32_FIND_DATAW type.
+// The obsolete items at the end of the struct in the docs are not actually present, except on Mac.
+//
+// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-win32_find_dataw
+type findData struct {
+	attributes    uint32 // File attribute flags; removeAll checks the FILE_ATTRIBUTE_DIRECTORY, READONLY and REPARSE_POINT bits.
+	creationTime  windows.Filetime
+	accessTime    windows.Filetime
+	writeTime     windows.Filetime
+	fileSizeHigh  uint32
+	fileSizeLow   uint32
+	reserved0     uint32 // Holds reparse point tag when file is a reparse point.
+	reserved1     uint32
+	name          [windows.MAX_PATH]uint16 // Null-terminated entry name; removeAll cuts it at the terminator with truncAtNull.
+	alternateName [14]uint16
+}
+
+// fileAttributeTagInfo represents Windows's FILE_ATTRIBUTE_TAG_INFO type.
+//
+// https://learn.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_attribute_tag_info
+type fileAttributeTagInfo struct {
+	attributes uint32 // File attribute flags (FILE_ATTRIBUTE_*).
+	tag        uint32 // Reparse tag; getFileInfo returns it as reparseTag.
+}
+
+var (
+	// Mockable for testing. removeAll deletes directories only through this variable.
+	removeDirectory = windows.RemoveDirectory
+)
+
+// RemoveAll attempts to be a Windows-specific replacement for os.RemoveAll.
+// Specifically, it treats a filesystem path as a []uint16 rather than a string. This allows it to work in some cases
+// where a string-based API would not, such as when a filesystem name is not valid UTF-16 (e.g. low surrogates without preceding high surrogates).
+// The path must not include a null terminator; one is added on a private copy, and path's backing array is never written to.
+// RemoveAll handles some dynamic changes to the directory tree while we are deleting it. For instance,
+// if a directory is not empty, we delete its children first, and then try again. If new items are added to the
+// directory while we do that, we will again recurse and delete them, and keep trying. This matches the behavior of
+// os.RemoveAll. Also like os.RemoveAll, nil is returned if path does not exist (ERROR_FILE_NOT_FOUND or ERROR_PATH_NOT_FOUND).
+// However, we don't attempt to handle every dynamic case, for instance:
+//   - If a file has FILE_ATTRIBUTE_READONLY re-set on it after we clear it, we will fail.
+//   - If an entry is changed from file to directory after we query its attributes, we will fail.
+func RemoveAll(path []uint16) error {
+	attrs, reparseTag, err := getFileInfo(path)
+	if errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND) {
+		return nil
+	} else if err != nil {
+		return err
+	}
+	return removeAll(path, attrs, reparseTag)
+}
+
+// removeAll is the lower-level routine for recursively deleting a file system entry.
+// The attributes and reparse tag of the top-level item to be deleted must have already been queried and are supplied as arguments.
+func removeAll(path []uint16, attrs uint32, reparseTag uint32) error {
+	if attrs&windows.FILE_ATTRIBUTE_DIRECTORY == 0 {
+		// File
+		if attrs&windows.FILE_ATTRIBUTE_READONLY != 0 {
+			// Read-only flag prevents deletion, so un-set it first.
+			if err := windows.SetFileAttributes(terminate(path), attrs&^windows.FILE_ATTRIBUTE_READONLY); err == windows.ERROR_FILE_NOT_FOUND {
+				return nil
+			} else if err != nil {
+				return pathErr("SetFileAttributes", path, err)
+			}
+		}
+		if err := windows.DeleteFile(terminate(path)); err == windows.ERROR_FILE_NOT_FOUND {
+			return nil
+		} else if err != nil {
+			return pathErr("DeleteFile", path, err)
+		}
+	} else {
+		// Directory
+		// Keep looping, removing children, and attempting to delete again.
+		for {
+			// First, try to delete the directory. This will only work if it's empty.
+			// If that fails then enumerate the entries and delete them first.
+			if err := removeDirectory(terminate(path)); err == nil || err == windows.ERROR_FILE_NOT_FOUND {
+				return nil
+			} else if err != windows.ERROR_DIR_NOT_EMPTY || (attrs&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 && reparseTag&reparseTagNameSurrogate != 0) {
+				// We either failed for some reason other than the directory not being empty, or because the directory is a name surrogate reparse point.
+				// We don't want to recurse into a name surrogate (e.g. symlink) because we will end up deleting stuff elsewhere on the system.
+				// In this case there's nothing else to do for this entry, so just return an error.
+				return pathErr("RemoveDirectory", path, err)
+			}
+			// Wrapping in an anonymous function so that the deferred FindClose call is scoped properly.
+			// Iterate the directory and remove all children by recursing into removeAll.
+			if err := func() error {
+				var fd findData
+				// findFirstFileExW allows us to avoid the shortname lookup. See comment at top of file for details.
+				pattern := join(path, []uint16{'*'})
+				find, err := findFirstFileExW(terminate(pattern), findExInfoBasic, &fd, findExSearchNameMatch, nil, 0)
+				if err == windows.ERROR_FILE_NOT_FOUND {
+					// No children is weird, because we should always have "." and "..".
+					// If we do hit this, just continue to the next deletion attempt.
+					return nil
+				} else if err != nil {
+					return pathErr("FindFirstFileEx", pattern, err)
+				}
+				defer windows.FindClose(find)
+				for {
+					var child []uint16
+					child, err = truncAtNull(fd.name[:])
+					if err != nil {
+						return err
+					}
+					if !equal(child, []uint16{'.'}) && !equal(child, []uint16{'.', '.'}) { // Ignore "." and ".."
+						if err := removeAll(join(path, child), fd.attributes, fd.reserved0); err != nil {
+							return err
+						}
+					}
+					err = findNextFileW(find, &fd)
+					if err == windows.ERROR_NO_MORE_FILES {
+						break
+					} else if err != nil {
+						return pathErr("FindNextFile", path, err)
+					}
+				}
+				return nil
+			}(); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func getFileInfo(path []uint16) (attrs uint32, reparseTag uint32, _ error) {
+	h, err := windows.CreateFile(terminate(path), 0, 0, nil, windows.OPEN_EXISTING, windows.FILE_OPEN_REPARSE_POINT|windows.FILE_FLAG_BACKUP_SEMANTICS, 0)
+	if err != nil {
+		return 0, 0, pathErr("CreateFile", path, err)
+	}
+	defer windows.CloseHandle(h) //nolint:errcheck
+	var ti fileAttributeTagInfo // Receives the entry's attributes and reparse tag from GetFileInformationByHandleEx.
+	if err = windows.GetFileInformationByHandleEx(h, windows.FileAttributeTagInfo, (*byte)(unsafe.Pointer(&ti)), uint32(unsafe.Sizeof(ti))); err != nil {
+		return 0, 0, pathErr("GetFileInformationByHandleEx", path, err)
+	}
+	return ti.attributes, ti.tag, nil
+}
+
+func equal(v1, v2 []uint16) bool {
+	if len(v1) != len(v2) {
+		return false
+	}
+	for i := range v1 {
+		if v1[i] != v2[i] {
+			return false
+		}
+	}
+	return true
+}
+
+func join(parent, child []uint16) []uint16 {
+	return append(append(parent[:len(parent):len(parent)], '\\'), child...) // Capacity is clamped so append always copies; parent's backing array is never written to.
+}
+
+func pathErr(op string, path []uint16, err error) error {
+	return &fs.PathError{Op: op, Path: windows.UTF16ToString(path), Err: err}
+}
+
+// terminate returns a pointer to a null-terminated copy of path; the copy is always freshly allocated, so path's backing array is untouched.
+func terminate(path []uint16) *uint16 {
+	return &append(path[:len(path):len(path)], '\u0000')[0]
+}
+
+// truncAtNull searches the input for a null terminator and returns a slice
+// up to that point.
It returns an error if there is no null terminator in the input.
+func truncAtNull(path []uint16) ([]uint16, error) {
+	for i, u := range path {
+		if u == '\u0000' {
+			return path[:i], nil // The result aliases the input and excludes the terminator.
+		}
+	}
+	return nil, errors.New("path is not null terminated")
+}
diff --git a/pkg/fs/removeall_windows_test.go b/pkg/fs/removeall_windows_test.go
new file mode 100644
index 00000000..ee059cc5
--- /dev/null
+++ b/pkg/fs/removeall_windows_test.go
@@ -0,0 +1,222 @@
+package fs
+
+import (
+	"errors"
+	"io/fs"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"testing"
+	"unicode/utf16"
+
+	"golang.org/x/sys/windows"
+)
+
+type fsEntry interface {
+	create(parent []uint16) error // Creates the entry on disk beneath the parent path.
+}
+
+type file struct {
+	name     string   // Name as a Go string; create encodes it to UTF-16.
+	rawName  []uint16 // Name as raw UTF-16 units, used when name is empty. Mutually exclusive with name.
+	readOnly bool     // When set, create passes FILE_ATTRIBUTE_READONLY to CreateFile.
+}
+
+func (f file) create(parent []uint16) error {
+	if f.name != "" && f.rawName != nil {
+		return errors.New("cannot set both name and rawName")
+	}
+	name := f.rawName // Used as-is unless name is set.
+	if f.name != "" {
+		name = utf16.Encode([]rune(f.name))
+	}
+	p := join(parent, name)
+	var attrs uint32
+	if f.readOnly {
+		attrs |= windows.FILE_ATTRIBUTE_READONLY
+	}
+	h, err := windows.CreateFile(
+		terminate(p),
+		windows.GENERIC_ALL,
+		0,
+		nil,
+		windows.CREATE_NEW,
+		attrs,
+		0)
+	if err != nil {
+		return pathErr("CreateFile", p, err)
+	}
+	return windows.CloseHandle(h) // The handle is only needed to create the file.
+}
+
+type dir struct {
+	name     string    // Name as a Go string; create encodes it to UTF-16.
+	rawName  []uint16  // Name as raw UTF-16 units, used when name is empty. Mutually exclusive with name.
+	children []fsEntry // Entries created inside the directory, in order, after the directory itself.
+}
+
+func (d dir) create(parent []uint16) error {
+	if d.name != "" && d.rawName != nil {
+		return errors.New("cannot set both name and rawName")
+	}
+	name := d.rawName // Used as-is unless name is set.
+	if d.name != "" {
+		name = utf16.Encode([]rune(d.name))
+	}
+	p := join(parent, name)
+	if err := windows.CreateDirectory(terminate(p), nil); err != nil {
+		return pathErr("CreateDirectory", p, err)
+	}
+	for _, c := range d.children {
+		if err := c.create(p); err != nil { // Stop at the first child that cannot be created.
+			return err
+		}
+	}
+	return nil
+}
+
+// We use junctions instead of symlinks because CreateSymbolicLink requires either
+// Administrator privileges or Developer Mode enabled, which
are both annoying to
+// require for test code to run.
+type junction struct {
+	name   string // Name of the junction created beneath the parent path.
+	target string // Directory the junction points at.
+}
+
+func (j junction) create(parent []uint16) error {
+	// There isn't a simple Windows API to create a junction, so we shell out for mklink instead.
+	// The alternative would be manually creating the reparse point buffer and calling
+	// the fsctl, which is too annoying for test code.
+	p := filepath.Join(windows.UTF16ToString(parent), j.name)
+	c := exec.Command("cmd.exe", "/c", "mklink", "/J", p, j.target)
+	if err := c.Run(); err != nil {
+		return &os.LinkError{Op: "mklink", New: p, Old: j.target, Err: err}
+	}
+	return nil
+}
+
+// TestRemoveAll creates a series of nested filesystem entries (files and directories) beneath a temp root,
+// then calls RemoveAll on the root, then tests to ensure the contents were deleted.
+func TestRemoveAll(t *testing.T) {
+	root := t.TempDir()
+	t.Logf("Root directory: %s", root)
+	rootU16 := utf16.Encode([]rune(root))
+
+	entries := []fsEntry{
+		dir{name: "dir", children: []fsEntry{
+			dir{name: "childdir", children: []fsEntry{
+				file{name: "bar.txt"},
+			}},
+			file{name: "baz.txt"},
+		}},
+		dir{name: "emptydir"},
+		dir{name: "fakeemptydir", children: []fsEntry{
+			file{name: "thisfilewillbedeleted"},
+		}},
+		file{name: "foo.txt"},
+		// This file name was seen in a real case where os.RemoveAll failed. It is invalid UTF-16 as it contains low surrogates that are not preceded by high surrogates ([1:5]).
+		file{rawName: []uint16{0x2e, 0xdc6d, 0xdc73, 0xdc79, 0xdc73, 0x30, 0x30, 0x30, 0x31, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x35, 0x36, 0x33, 0x39, 0x64, 0x64, 0x35, 0x30, 0x61, 0x37, 0x32, 0x61, 0x62, 0x34, 0x36, 0x36, 0x38, 0x62, 0x33, 0x33}},
+		file{name: "readonlyfile", readOnly: true},
+	}
+	for _, entry := range entries {
+		if err := entry.create(rootU16); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	if err := RemoveAll(rootU16); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := os.Lstat(root); !os.IsNotExist(err) {
+		t.Errorf("root dir exists when it should not: %s", root)
+	}
+}
+
+// TestRemoveAllDontFollowSymlinks creates a junction pointing to another directory, then calls RemoveAll on the directory
+// containing the junction, then tests to ensure the referenced directory or its contents were not deleted.
+func TestRemoveAllDontFollowSymlinks(t *testing.T) {
+	root := t.TempDir()
+	t.Logf("Root directory: %s", root)
+	rootU16 := utf16.Encode([]rune(root))
+	junctionDir := t.TempDir()
+	t.Logf("Junction directory: %s", junctionDir)
+	// We will later test to ensure fileinjunction still exists.
+	if err := (file{name: "fileinjunction"}.create(utf16.Encode([]rune(junctionDir)))); err != nil {
+		t.Fatal(err)
+	}
+
+	entries := []fsEntry{
+		junction{name: "link", target: junctionDir},
+	}
+	for _, entry := range entries {
+		if err := entry.create(rootU16); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	if err := RemoveAll(rootU16); err != nil {
+		t.Errorf("RemoveAll failed: %s", err)
+	}
+
+	if _, err := os.Lstat(root); !os.IsNotExist(err) {
+		t.Errorf("root dir exists when it should not: %s", root)
+	}
+	if _, err := os.Lstat(junctionDir); err != nil {
+		t.Errorf("junction dir may have been deleted when it should not be: %s", err)
+	}
+	if _, err := os.Lstat(filepath.Join(junctionDir, "fileinjunction")); err != nil {
+		t.Errorf("file in junction dir may have been deleted when it should not be: %s", err)
+	}
+}
+
+// TestRemoveAllShouldFailWhenSymlinkDeletionFails creates a junction pointing to another directory, then calls RemoveAll on the directory
+// containing it, then tests to ensure the referenced directory or its contents were not deleted. In this case removeDirectory is mocked to fail
+// the first time on the junction, so RemoveAll must return that error. This is to ensure we don't accidentally recurse into the symlink when this happens.
+func TestRemoveAllShouldFailWhenSymlinkDeletionFails(t *testing.T) {
+	root := t.TempDir()
+	t.Logf("Root directory: %s", root)
+	rootU16 := utf16.Encode([]rune(root))
+	junctionDir := t.TempDir()
+	t.Logf("Junction directory: %s", junctionDir)
+	// We will later test to ensure fileinjunction still exists.
+	if err := (file{name: "fileinjunction"}.create(utf16.Encode([]rune(junctionDir)))); err != nil {
+		t.Fatal(err)
+	}
+
+	entries := []fsEntry{
+		junction{name: "link", target: junctionDir},
+	}
+	for _, entry := range entries {
+		if err := entry.create(rootU16); err != nil {
+			t.Fatal(err) // Fixture creation failed; nothing below is meaningful without the junction.
+		}
+	}
+
+	var linkDeleteAttempted bool
+	t.Cleanup(func() { removeDirectory = windows.RemoveDirectory }) // Undo the mock so it does not leak into other tests.
+	removeDirectory = func(path *uint16) error {
+		if _, name := filepath.Split(windows.UTF16PtrToString(path)); !linkDeleteAttempted && name == "link" {
+			linkDeleteAttempted = true
+			return windows.ERROR_DIR_NOT_EMPTY
+		}
+		return windows.RemoveDirectory(path)
+	}
+
+	// We expect RemoveAll to return an error as it failed to delete the link. A nil error is a test failure too.
+	err := RemoveAll(rootU16)
+	var perr *fs.PathError
+	if !errors.As(err, &perr) {
+		t.Fatalf("RemoveAll returned %v, expected a *fs.PathError", err) // Fatal: perr is nil and must not be dereferenced below.
+	}
+	if _, name := filepath.Split(perr.Path); perr.Op != "RemoveDirectory" || perr.Err != windows.ERROR_DIR_NOT_EMPTY || name != "link" {
+		t.Errorf("RemoveAll failed: %s", err)
+	}
+
+	if _, err := os.Lstat(junctionDir); err != nil {
+		t.Errorf("junction dir may have been deleted when it should not be: %s", err)
+	}
+	if _, err := os.Lstat(filepath.Join(junctionDir, "fileinjunction")); err != nil {
+		t.Errorf("file in junction dir may have been deleted when it should not be: %s", err)
+	}
+}
diff --git a/pkg/fs/syscall_windows.go b/pkg/fs/syscall_windows.go
new file mode 100644
index 00000000..88f87586
--- /dev/null
+++ b/pkg/fs/syscall_windows.go
@@ -0,0 +1,3 @@
+package fs
+
+//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
diff --git a/pkg/fs/zsyscall_windows.go b/pkg/fs/zsyscall_windows.go
new file mode 100644
index 00000000..c7e27662
--- /dev/null
+++ b/pkg/fs/zsyscall_windows.go
@@ -0,0 +1,64 @@
+//go:build windows
+
+// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
+ +package fs + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procFindFirstFileExW = modkernel32.NewProc("FindFirstFileExW") + procFindNextFileW = modkernel32.NewProc("FindNextFileW") +) + +func findFirstFileExW(pattern *uint16, infoLevel uint32, data *findData, searchOp uint32, searchFilter unsafe.Pointer, flags uint32) (h windows.Handle, err error) { + r0, _, e1 := syscall.Syscall6(procFindFirstFileExW.Addr(), 6, uintptr(unsafe.Pointer(pattern)), uintptr(infoLevel), uintptr(unsafe.Pointer(data)), uintptr(searchOp), uintptr(searchFilter), uintptr(flags)) + h = windows.Handle(r0) + if h == windows.InvalidHandle { + err = errnoErr(e1) + } + return +} + +func findNextFileW(findHandle windows.Handle, data *findData) (err error) { + r1, _, e1 := syscall.Syscall(procFindNextFileW.Addr(), 2, uintptr(findHandle), uintptr(unsafe.Pointer(data)), 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} diff --git a/syscall.go b/syscall_windows.go similarity index 85% rename from syscall.go rename to syscall_windows.go index a6ca111b..ca0de234 100644 --- a/syscall.go +++ b/syscall_windows.go @@ -1,5 +1,3 @@ -//go:build windows - package winio //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go