Skip to content

Commit

Permalink
refactor and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
patilpankaj212 committed Jan 27, 2021
1 parent 5f706bd commit 90fe2cb
Show file tree
Hide file tree
Showing 8 changed files with 421 additions and 62 deletions.
18 changes: 8 additions & 10 deletions pkg/downloader/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ var (
ErrInvalidRemoteType = fmt.Errorf("supplied remote type is not supported")
)

// NewGoGetter returns a new GoGetter struct
func NewGoGetter() *GoGetter {
return &GoGetter{
cache: make(map[string]string),
}
// newGoGetter returns a new GoGetter struct
func newGoGetter() *goGetter {
return &goGetter{}
}

// GetURLSubDir returns the download URL with it's respective type prefix
// along with subDir path, if present.
func (g *GoGetter) GetURLSubDir(remoteURL, destPath string) (string, string, error) {
func (g *goGetter) GetURLSubDir(remoteURL, destPath string) (string, string, error) {

// get subDir, if present
repoURL, subDir := SplitAddrSubdir(remoteURL)
Expand Down Expand Up @@ -76,7 +74,7 @@ func (g *GoGetter) GetURLSubDir(remoteURL, destPath string) (string, string, err
// Download retrieves the remote repository referenced in the given remoteURL
// into the destination path and then returns the full path to any subdir
// indicated in the URL
func (g *GoGetter) Download(remoteURL, destPath string) (string, error) {
func (g *goGetter) Download(remoteURL, destPath string) (string, error) {

zap.S().Debugf("download with remote url: %q, destination dir: %q",
remoteURL, destPath)
Expand Down Expand Up @@ -134,7 +132,7 @@ func (g *GoGetter) Download(remoteURL, destPath string) (string, error) {
//
// DownloadWithType enforces download type on go-getter to get rid of any
// ambiguities in remoteURL
func (g *GoGetter) DownloadWithType(remoteType, remoteURL, destPath string) (string, error) {
func (g *goGetter) DownloadWithType(remoteType, remoteURL, destPath string) (string, error) {

zap.S().Debugf("download with remote type: %q, remote URL: %q, destination dir: %q",
remoteType, remoteURL, destPath)
Expand Down Expand Up @@ -168,7 +166,7 @@ func (g *GoGetter) DownloadWithType(remoteType, remoteURL, destPath string) (str
}
versionConstraints.Required = versionConstraint
}
return g.DownloadRemoteModule(versionConstraints, destPath, module)
return NewRemoteDownloader().DownloadRemoteModule(versionConstraints, destPath, module)
}
return "", fmt.Errorf("%s, is not a valid terraform registry", remoteURL)
}
Expand All @@ -180,7 +178,7 @@ func (g *GoGetter) DownloadWithType(remoteType, remoteURL, destPath string) (str
}

// SubDirGlob returns the actual subdir with globbing processed
func (g *GoGetter) SubDirGlob(destPath, subDir string) (string, error) {
func (g *goGetter) SubDirGlob(destPath, subDir string) (string, error) {
return getter.SubdirGlob(destPath, subDir)
}

Expand Down
71 changes: 60 additions & 11 deletions pkg/downloader/getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ var (
func TestNewGoGetter(t *testing.T) {
t.Run("new GoGetter", func(t *testing.T) {
var (
want = &GoGetter{
cache: make(map[string]string),
}
got = NewGoGetter()
want = &goGetter{}
got = newGoGetter()
)
if !reflect.DeepEqual(got, want) {
t.Errorf("got: '%v', want: '%v'", got, want)
Expand Down Expand Up @@ -128,7 +126,7 @@ func TestGetURLSubDir(t *testing.T) {
}

for _, tt := range table {
g := NewGoGetter()
g := newGoGetter()
gotURL, gotSubDir, gotErr := g.GetURLSubDir(tt.URL, tt.dest)
if !reflect.DeepEqual(gotURL, tt.wantURL) {
t.Errorf("url got: '%v', want: '%v'", gotURL, tt.wantURL)
Expand All @@ -150,6 +148,8 @@ func TestDownload(t *testing.T) {
dest string
wantDest string
wantErr error
// when error is expected, but assertion is not required
skipErrAssert bool
}{
{
name: "empty URL",
Expand All @@ -172,14 +172,27 @@ func TestDownload(t *testing.T) {
wantDest: "",
wantErr: fmt.Errorf("GitHub URLs should be github.com/username/repo"),
},
{
name: "valid url, non existing repo",
URL: "github.com/testuser/testrepo",
dest: someDest,
wantDest: "",
skipErrAssert: true,
},
}

for _, tt := range table {
t.Run(tt.name, func(t *testing.T) {
g := NewGoGetter()
g := newGoGetter()
gotDest, gotErr := g.Download(tt.URL, tt.dest)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("error got: '%v', want: '%v'", gotErr, tt.wantErr)
if !tt.skipErrAssert {
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("error got: '%v', want: '%v'", gotErr, tt.wantErr)
}
} else {
if gotErr == nil {
t.Error("error expected")
}
}
if !reflect.DeepEqual(gotDest, tt.wantDest) {
t.Errorf("dest got: '%v', want: '%v'", gotDest, tt.wantDest)
Expand All @@ -190,13 +203,19 @@ func TestDownload(t *testing.T) {

func TestDownloadWithType(t *testing.T) {

remoteTypeTerraformRegistry := "terraform-registry"
testInvalidRegistrySource := "test/some-url"
testValidNonExistentRegistrySource := "terraform-aws-modules/xyz/aws:1.0.0"

table := []struct {
name string
Type string
URL string
dest string
wantDest string
wantErr error
// when error is expected, but assertion is not required
skipErrAssert bool
}{
{
name: "empty URL and Type",
Expand Down Expand Up @@ -246,14 +265,44 @@ func TestDownloadWithType(t *testing.T) {
wantDest: "",
wantErr: fmt.Errorf("GitHub URLs should be github.com/username/repo"),
},
{
name: "terraform-registry remote type with invalid source addr",
Type: remoteTypeTerraformRegistry,
URL: testInvalidRegistrySource,
dest: someDest,
wantDest: "",
wantErr: fmt.Errorf("%s, is not a valid terraform registry", testInvalidRegistrySource),
},
{
name: "terraform-registry remote type with valid non-existent source addr",
Type: remoteTypeTerraformRegistry,
URL: testValidNonExistentRegistrySource,
dest: someDest,
wantDest: "",
skipErrAssert: true,
},
{
name: "terraform-registry remote type with invalid version",
Type: remoteTypeTerraformRegistry,
URL: "terraform-aws-modules/xyz/aws:x.y",
dest: someDest,
wantDest: "",
skipErrAssert: true,
},
}

for _, tt := range table {
t.Run(tt.name, func(t *testing.T) {
g := NewGoGetter()
g := newGoGetter()
gotDest, gotErr := g.DownloadWithType(tt.Type, tt.URL, tt.dest)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("error got: '%v', want: '%v'", gotErr, tt.wantErr)
if !tt.skipErrAssert {
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("error got: '%v', want: '%v'", gotErr, tt.wantErr)
}
} else {
if gotErr == nil {
t.Error("error expected")
}
}
if !reflect.DeepEqual(gotDest, tt.wantDest) {
t.Errorf("dest got: '%v', want: '%v'", gotDest, tt.wantDest)
Expand Down
23 changes: 22 additions & 1 deletion pkg/downloader/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package downloader
import (
hclConfigs "github.com/hashicorp/terraform/configs"
"github.com/hashicorp/terraform/registry/regsrc"
"github.com/hashicorp/terraform/registry/response"
)

// Downloader helps in downloading different kinds of modules from
Expand All @@ -28,12 +29,32 @@ type Downloader interface {
DownloadWithType(remoteType, url, dest string) (finalDir string, err error)
GetURLSubDir(url, dest string) (urlWithType string, subDir string, err error)
SubDirGlob(string, string) (string, error)
}

// ModuleDownloader helps in downloading the remote modules
type ModuleDownloader interface {
DownloadModule(addr, destPath string) (string, error)
DownloadRemoteModule(requiredVersion hclConfigs.VersionConstraint, destPath string, module *regsrc.Module) (string, error)
CleanUp()
}

// terraformRegistryClient will help interact with terraform registries
type terraformRegistryClient interface {
ModuleVersions(module *regsrc.Module) (*response.ModuleVersions, error)
ModuleLocation(module *regsrc.Module, version string) (string, error)
}

// NewDownloader returns a new downloader
func NewDownloader() Downloader {
return NewGoGetter()
return newGoGetter()
}

// NewRemoteDownloader returns a new ModuleDownloader
func NewRemoteDownloader() ModuleDownloader {
return newRemoteModuleInstaller()
}

// newClientRegistry returns a terraformClientRegistry to query terraform registries
func newClientRegistry() terraformRegistryClient {
return newTerraformRegistryClient()
}
47 changes: 30 additions & 17 deletions pkg/downloader/module-download.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,47 @@ import (
"go.uber.org/zap"
)

// newRemoteModuleInstaller returns a RemoteModuleInstaller initialized with a
// new cache, downloader and terraform registry client
func newRemoteModuleInstaller() *remoteModuleInstaller {
return &remoteModuleInstaller{
cache: make(map[string]string),
downloader: NewDownloader(),
terraformRegistryClient: newClientRegistry(),
}
}

// newTerraformRegistryClient returns a client to query terraform registries
func newTerraformRegistryClient() terraformRegistryClient {
// get terraform registry client.
// terraform registry client provides methods for querying the terraform module registry
return registry.NewClient(nil, nil)
}

// DownloadModule retrieves the package referenced in the given address
// into the installation path and then returns the full path to any subdir
// indicated in the address.
func (g *GoGetter) DownloadModule(addr, destPath string) (string, error) {
func (r *remoteModuleInstaller) DownloadModule(addr, destPath string) (string, error) {

// split url and subdir
URLWithType, subDir, err := g.GetURLSubDir(addr, destPath)
URLWithType, subDir, err := r.downloader.GetURLSubDir(addr, destPath)
if err != nil {
return "", err
}

// check if the module has already been downloaded
if prevDir, exists := g.cache[URLWithType]; exists {
if prevDir, exists := r.cache[URLWithType]; exists {
zap.S().Debugf("module %q already installed at %q", URLWithType, prevDir)
destPath = prevDir
} else {
destPath, err := g.Download(URLWithType, destPath)
destPath, err := r.downloader.Download(URLWithType, destPath)
if err != nil {
zap.S().Debugf("failed to download remote module. error: '%v'", err)
return "", err
}
// Remember where we installed this so we might reuse this directory
// on subsequent calls to avoid re-downloading.
g.cache[URLWithType] = destPath
r.cache[URLWithType] = destPath
}

// Our subDir string can contain wildcards until this point, so that
Expand All @@ -61,7 +78,7 @@ func (g *GoGetter) DownloadModule(addr, destPath string) (string, error) {
// resolve that into a concrete path.
var finalDir string
if subDir != "" {
finalDir, err = g.SubDirGlob(destPath, subDir)
finalDir, err = r.downloader.SubDirGlob(destPath, subDir)
if err != nil {
return "", err
}
Expand All @@ -77,7 +94,7 @@ func (g *GoGetter) DownloadModule(addr, destPath string) (string, error) {

// DownloadRemoteModule will download remote modules from public and private terraform registries
// this function takes similar approach taken by terraform init for downloading terraform registry modules
func (g *GoGetter) DownloadRemoteModule(requiredVersion hclConfigs.VersionConstraint, destPath string, module *regsrc.Module) (string, error) {
func (r *remoteModuleInstaller) DownloadRemoteModule(requiredVersion hclConfigs.VersionConstraint, destPath string, module *regsrc.Module) (string, error) {
// Terraform doesn't allow the hostname to contain Punycode
// module.SvcHost returns an error for such case
_, err := module.SvcHost()
Expand All @@ -86,12 +103,8 @@ func (g *GoGetter) DownloadRemoteModule(requiredVersion hclConfigs.VersionConstr
return "", err
}

// get terraform registry client.
// terraform registry client provides methods for querying the terraform module registry
regClient := registry.NewClient(nil, nil)

// get all the available module versions from the terraform registry
moduleVersions, err := regClient.ModuleVersions(module)
moduleVersions, err := r.terraformRegistryClient.ModuleVersions(module)
if err != nil {
if registry.IsModuleNotFound(err) {
zap.S().Errorf("module: %s, not be found at registry: %s", module.String(), module.Host().Display())
Expand All @@ -109,16 +122,16 @@ func (g *GoGetter) DownloadRemoteModule(requiredVersion hclConfigs.VersionConstr
}

// get the source location for the matched version
sourceLocation, err := regClient.ModuleLocation(module, versionToDownload.String())
sourceLocation, err := r.terraformRegistryClient.ModuleLocation(module, versionToDownload.String())
if err != nil {
zap.S().Errorf("error while getting the source location for module: %s, at registry: %s", module.String(), module.Host().Display())
return "", err
}

downloadLocation, err := g.DownloadModule(sourceLocation, destPath)
downloadLocation, err := r.DownloadModule(sourceLocation, destPath)
if err != nil {
zap.S().Errorf("error while downloading module: %s, with source location: %s", module.String(), sourceLocation)
return "", nil
return "", err
}

if module.RawSubmodule != "" {
Expand All @@ -130,8 +143,8 @@ func (g *GoGetter) DownloadRemoteModule(requiredVersion hclConfigs.VersionConstr
}

// CleanUp cleans up all the locally downloaded modules
func (g *GoGetter) CleanUp() {
for url, path := range g.cache {
func (r *remoteModuleInstaller) CleanUp() {
for url, path := range r.cache {
zap.S().Debugf("deleting %q installed at %q", url, path)
os.RemoveAll(path)
}
Expand Down
Loading

0 comments on commit 90fe2cb

Please sign in to comment.