diff --git a/cached_downloader.go b/cached_downloader.go index cd25860..840d53e 100644 --- a/cached_downloader.go +++ b/cached_downloader.go @@ -23,7 +23,7 @@ type CachedDownloader interface { } func NoopTransform(source, destination string) (int64, error) { - err := os.Rename(source, destination) + err := replace(source, destination) if err != nil { return 0, err } diff --git a/file_cache.go b/file_cache.go index f949ae0..f13b765 100644 --- a/file_cache.go +++ b/file_cache.go @@ -105,7 +105,7 @@ func (c *FileCache) Add(cacheKey string, sourcePath string, size int64, cachingI uniqueName := fmt.Sprintf("%s-%d-%d", cacheKey, time.Now().UnixNano(), c.seq) cachePath := filepath.Join(c.cachedPath, uniqueName) - err := os.Rename(sourcePath, cachePath) + err := replace(sourcePath, cachePath) if err != nil { return nil, err } diff --git a/replace.go b/replace.go new file mode 100644 index 0000000..b4973a9 --- /dev/null +++ b/replace.go @@ -0,0 +1,9 @@ +// +build !windows + +package cacheddownloader + +import "os" + +func replace(src, dst string) error { + return os.Rename(src, dst) +} diff --git a/replace_windows.go b/replace_windows.go new file mode 100644 index 0000000..6af311c --- /dev/null +++ b/replace_windows.go @@ -0,0 +1,41 @@ +package cacheddownloader + +import ( + "syscall" + "unsafe" +) + +func replace(src, dst string) error { + kernel32, err := syscall.LoadLibrary("kernel32.dll") + if err != nil { + return err + } + defer syscall.FreeLibrary(kernel32) + moveFileExUnicode, err := syscall.GetProcAddress(kernel32, "MoveFileExW") + if err != nil { + return err + } + + srcString, err := syscall.UTF16PtrFromString(src) + if err != nil { + return err + } + + dstString, err := syscall.UTF16PtrFromString(dst) + if err != nil { + return err + } + + srcPtr := uintptr(unsafe.Pointer(srcString)) + dstPtr := uintptr(unsafe.Pointer(dstString)) + + MOVEFILE_REPLACE_EXISTING := 0x1 + flag := uintptr(MOVEFILE_REPLACE_EXISTING) + + _, _, callErr := syscall.Syscall(uintptr(moveFileExUnicode), 3, srcPtr, dstPtr, flag) + if callErr != 0 { + return callErr + } + + return nil +}