Skip to content

Commit

Permalink
Fix support for getting single files from GitHub
Browse files Browse the repository at this point in the history
This currently fails for GitHub (and possibly other Git repositories)
because there is an assumption in the "client.go" file that the "subDir"
segment is always a directory.

However, for this URL:

https://github.com/owner/repo/subdir/file.txt

The "subDir" variable would actually be a file ("subdir/file.txt") and
hence the client should use "copyFile" rather than "copyDir" when
copying the "subDir" to the final destination.

This change fixes that to check whether the "subDir" variable is a file
or a directory, and appropriately uses "copyDir" or "copyFile". It also
returns the single file in the "GetResult" object, if indeed a single
file was requested.

Tests attached.
  • Loading branch information
arikkfir committed Aug 3, 2022
1 parent 31c3313 commit 66abd65
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 4 deletions.
41 changes: 37 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -317,11 +318,43 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
return nil, &getError{true, err}
}

err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask())
if err != nil {
return nil, &getError{false, err}
if stat, err := os.Stat(subDir); err != nil {
return nil, &getError{false, fmt.Errorf("failed to stat '%s': %w", subDir, err)}
} else if stat.IsDir() {
err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask())
if err != nil {
return nil, &getError{false, err}
}
return &GetResult{req.realDst}, nil
} else {
src, err := os.Open(subDir)
if err != nil {
return nil, &getError{false, fmt.Errorf("failed to open local source file at '%s': %w", subDir, err)}
}
//goland:noinspection GoUnhandledErrorResult
defer src.Close()

target := filepath.Join(req.realDst, filepath.Base(subDir))
dst, err := os.Create(target)
if err != nil {
return nil, &getError{false, fmt.Errorf("failed to open local target file at '%s': %w", target, err)}
}
//goland:noinspection GoUnhandledErrorResult
defer dst.Close()

buf := make([]byte, 1024*20) // 20k buffer should usually suffice for 99% of files
for {
n, err := src.Read(buf)
if err != nil && err != io.EOF {
return nil, &getError{false, fmt.Errorf("failed to read local source file at '%s': %w", subDir, err)}
} else if n == 0 {
break
} else if _, err := dst.Write(buf[:n]); err != nil {
return nil, &getError{false, fmt.Errorf("failed to write to local source file at '%s': %w", target, err)}
}
}
return &GetResult{target}, nil
}
return &GetResult{req.realDst}, nil
}

return &GetResult{req.Dst}, nil
Expand Down
76 changes: 76 additions & 0 deletions get_github_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package getter

import (
"context"
testing_helper "github.com/hashicorp/go-getter/v2/helper/testing"
"os"
"path/filepath"
"testing"
)

const basicMainTFExpectedContents = `# Hello
module "foo" {
source = "./foo"
}
`

func TestGitGetter_githubDirWithModeAny(t *testing.T) {
if !testHasGit {
t.Skip("git not found, skipping")
}

ctx := context.Background()
dst := testing_helper.TempDir(t)
defer os.RemoveAll(dst)

req := &Request{
Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic?ref=v2",
Dst: dst,
GetMode: ModeAny,
Copy: true,
}
client := Client{}
result, err := client.Get(ctx, req)
if err != nil {
t.Fatalf("Failed fetching GitHub directory: %s", err)
} else if stat, err := os.Stat(result.Dst); err != nil {
t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err)
} else if !stat.IsDir() {
t.Fatalf("Expected '%s' to be a directory", result.Dst)
} else if entries, err := os.ReadDir(result.Dst); err != nil {
t.Fatalf("Failed listing directory '%s': %s", result.Dst, err)
} else if len(entries) != 3 {
t.Fatalf("Expected dir '%s' to contain 3 items: %s", result.Dst, err)
} else {
testing_helper.AssertContents(t, filepath.Join(result.Dst, "main.tf"), basicMainTFExpectedContents)
}
}

func TestGitGetter_githubFileWithModeAny(t *testing.T) {
if !testHasGit {
t.Skip("git not found, skipping")
}

ctx := context.Background()
dst := testing_helper.TempDir(t)
defer os.RemoveAll(dst)

req := &Request{
Src: "git::https://github.com/arikkfir/go-getter.git//testdata/basic/main.tf?ref=v2",
Dst: dst,
GetMode: ModeAny,
Copy: true,
}
client := Client{}
result, err := client.Get(ctx, req)
if err != nil {
t.Fatalf("Failed fetching GitHub file: %s", err)
} else if stat, err := os.Stat(result.Dst); err != nil {
t.Fatalf("Failed stat dst at '%s': %s", result.Dst, err)
} else if stat.IsDir() {
t.Fatalf("Expected '%s' to be a file", result.Dst)
} else {
testing_helper.AssertContents(t, result.Dst, basicMainTFExpectedContents)
}
}

0 comments on commit 66abd65

Please sign in to comment.