From 66abd65295a5bcf347fba28725f2a008f071b910 Mon Sep 17 00:00:00 2001 From: Arik Kfir Date: Wed, 3 Aug 2022 19:06:40 +0300 Subject: [PATCH] Fix support for getting single files from GitHub This currently fails for GitHub (and possibly other Git repositories) because there is an assumption in the "client.go" file that the "subDir" segment is always a directory. However, for this URL: https://github.com/owner/repo/subdir/file.txt The "subDir" variable would actually be a file ("subdir/file.txt") and hence the client should use "copyFile" rather than "copyDir" when copying the "subDir" to the final destination. This change fixes that to check whether the "subDir" variable is a file or a directory, and appropriately uses "copyDir" or "copyFile". It also returns the single file in the "GetResult" object, if indeed a single file was requested. Tests attached. --- client.go | 41 ++++++++++++++++++++++--- get_github_test.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 get_github_test.go diff --git a/client.go b/client.go index 3aa5dd1d5..ed84783f8 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -317,11 +318,43 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, * return nil, &getError{true, err} } - err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask()) - if err != nil { - return nil, &getError{false, err} + if stat, err := os.Stat(subDir); err != nil { + return nil, &getError{false, fmt.Errorf("failed to stat '%s': %w", subDir, err)} + } else if stat.IsDir() { + err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask()) + if err != nil { + return nil, &getError{false, err} + } + return &GetResult{req.realDst}, nil + } else { + src, err := os.Open(subDir) + if err != nil { + return nil, &getError{false, fmt.Errorf("failed to open local source file at '%s': %w", subDir, err)} + } + //goland:noinspection GoUnhandledErrorResult + defer src.Close() + + target := filepath.Join(req.realDst, filepath.Base(subDir)) + dst, err := os.Create(target) + if err != nil { + return nil, &getError{false, fmt.Errorf("failed to open local target file at '%s': %w", target, err)} + } + //goland:noinspection GoUnhandledErrorResult + defer dst.Close() + + buf := make([]byte, 1024*20) // 20k buffer should usually suffice for 99% of files + for { + n, err := src.Read(buf) + if err != nil && err != io.EOF { + return nil, &getError{false, fmt.Errorf("failed to read local source file at '%s': %w", subDir, err)} + } else if n == 0 { + break + } else if _, err := dst.Write(buf[:n]); err != nil { + return nil, &getError{false, fmt.Errorf("failed to write to local source file at '%s': %w", target, err)} + } + } + return &GetResult{target}, nil } - return &GetResult{req.realDst}, nil } return &GetResult{req.Dst}, nil diff --git a/get_github_test.go b/get_github_test.go new file mode 100644 index 000000000..bb4f80087 --- /dev/null +++ b/get_github_test.go @@ -0,0 +1,76 @@ +package getter + +import ( + "context" + testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" + "os" + "path/filepath" + "testing" +) + +const basicMainTFExpectedContents = `# Hello + +module "foo" { + source = "./foo" +} +` + +func TestGitGetter_githubDirWithModeAny(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + ctx := context.Background() + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + req := &Request{ + Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic?ref=v2", + Dst: dst, + GetMode: ModeAny, + Copy: true, + } + client := Client{} + result, err := client.Get(ctx, req) + if err != nil { + t.Fatalf("Failed fetching GitHub directory: %s", err) + } else if stat, err := os.Stat(result.Dst); err != nil { + t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err) + } else if !stat.IsDir() { + t.Fatalf("Expected '%s' to be a directory", result.Dst) + } else if entries, err := os.ReadDir(result.Dst); err != nil { + t.Fatalf("Failed listing directory '%s': %s", result.Dst, err) + } else if len(entries) != 3 { + t.Fatalf("Expected dir '%s' to contain 3 items: %s", result.Dst, err) + } else { + testing_helper.AssertContents(t, filepath.Join(result.Dst, "main.tf"), basicMainTFExpectedContents) + } +} + +func TestGitGetter_githubFileWithModeAny(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + ctx := context.Background() + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + req := &Request{ + Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic/main.tf?ref=v2", + Dst: dst, + GetMode: ModeAny, + Copy: true, + } + client := Client{} + result, err := client.Get(ctx, req) + if err != nil { + t.Fatalf("Failed fetching GitHub file: %s", err) + } else if stat, err := os.Stat(result.Dst); err != nil { + t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err) + } else if stat.IsDir() { + t.Fatalf("Expected '%s' to be a file", result.Dst) + } else { + testing_helper.AssertContents(t, result.Dst, basicMainTFExpectedContents) + } +}