Skip to content

Commit

Permalink
Refactor plugin staging to make it more testable
Browse files Browse the repository at this point in the history
Separates out a lot of pieces of the main function into smaller
functions. This helps with readability as well as for future unit
testing.
  • Loading branch information
joereuss12 committed Mar 11, 2024
1 parent 7e5c9e1 commit d85e199
Showing 1 changed file with 119 additions and 84 deletions.
203 changes: 119 additions & 84 deletions cmd/plugin_stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package main
import (
"context"
"fmt"
"io"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -89,34 +90,18 @@ func stagePluginMain(cmd *cobra.Command, args []string) {
ctx := cmd.Context()

originPrefixStr := param.StagePlugin_OriginPrefix.GetString()
if len(originPrefixStr) == 0 {
log.Errorln("Origin prefix not specified; must be a URL (osdf://...)")
os.Exit(1)
}
originPrefixUri, err := url.Parse(originPrefixStr)
mountPrefixStr := param.StagePlugin_MountPrefix.GetString()
shadowOriginPrefixStr := param.StagePlugin_ShadowOriginPrefix.GetString()

originPrefixUri, err := validatePrefixes(originPrefixStr, mountPrefixStr, shadowOriginPrefixStr)
if err != nil {
log.Errorln("Origin prefix must be a URL (osdf://...):", err)
os.Exit(1)
}
if originPrefixUri.Scheme != "osdf" {
log.Errorln("Origin prefix scheme must be osdf://:", originPrefixUri.Scheme)
log.Errorln("Problem validating provided prefixes:", err)
os.Exit(1)
}

originPrefixPath := path.Clean("/" + originPrefixUri.Host + "/" + originPrefixUri.Path)
log.Debugln("Local origin prefix:", originPrefixPath)

mountPrefixStr := param.StagePlugin_MountPrefix.GetString()
if len(mountPrefixStr) == 0 {
log.Errorln("Mount prefix is required; must be a local path (/mnt/foo/...)")
os.Exit(1)
}

shadowOriginPrefixStr := param.StagePlugin_ShadowOriginPrefix.GetString()
if len(shadowOriginPrefixStr) == 0 {
log.Errorln("Shadow origin prefix is required; must be a URL (osdf://....)")
os.Exit(1)
}

tokenLocation := param.Plugin_Token.GetString()

pb := newProgressBar()
Expand All @@ -128,52 +113,12 @@ func stagePluginMain(cmd *cobra.Command, args []string) {
pb.launchDisplay(ctx)
}

var sources []string
var extraSources []string
isHook := param.StagePlugin_Hook.GetBool()
if isHook {
buffer := make([]byte, 100*1024)
bytesread, err := os.Stdin.Read(buffer)
if err != nil {
log.Errorln("Failed to read ClassAd from stdin:", err)
os.Exit(1)
}
classad, err := classads.ParseShadowClassAd(string(buffer[:bytesread]))
if err != nil {
log.Errorln("Failed to parse ClassAd from stdin: ", err)
os.Exit(1)
}
inputList, err := classad.Get("TransferInput")
if err != nil || inputList == nil {
// No TransferInput, no need to transform...
log.Debugln("No transfer input found in classad, no need to transform.")
os.Exit(0)
}
inputListStr, ok := inputList.(string)
if !ok {
log.Errorln("TransferInput is not a string")
os.Exit(1)
}
re := regexp.MustCompile(`[,\s]+`)
for _, source := range re.Split(inputListStr, -1) {
log.Debugln("Examining transfer input file", source)
if strings.HasPrefix(source, mountPrefixStr) {
sources = append(sources, source)
} else {
// Replace the osdf:// prefix with the local mount path
source_uri, err := url.Parse(source)
source_uri_scheme := strings.SplitN(source_uri.Scheme, "+", 2)[0]
if err == nil && source_uri_scheme == "osdf" {
source_path := path.Clean("/" + source_uri.Host + "/" + source_uri.Path)
if strings.HasPrefix(source_path, originPrefixPath) {
sources = append(sources, mountPrefixStr+source_path[len(originPrefixPath):])
continue
}
}
extraSources = append(extraSources, source)
}
}
} else {
var sources, extraSources []string
var exitCode int

// If not a condor hook, our souces come from our args
if !isHook {
log.Debugln("Len of source:", len(args))
if len(args) < 1 {
log.Errorln("No ingest sources")
Expand All @@ -183,11 +128,39 @@ func stagePluginMain(cmd *cobra.Command, args []string) {
os.Exit(1)
}
sources = args
log.Debugln("Sources:", sources)
} else { // Otherwise, parse the classad for our sources
// We pass in stdin here because that is how we get the classad
sources, extraSources, err, exitCode = processTransferInput(os.Stdin, mountPrefixStr, originPrefixPath)
if err != nil {
log.Errorln("Failure to get sources from job's classad:", err)
os.Exit(exitCode)
}
}
log.Debugln("Sources:", sources)

var result error
var xformSources []string

xformSources, result = doPluginStaging(sources, extraSources, mountPrefixStr, shadowOriginPrefixStr, tokenLocation)
// Exit with failure
if result != nil {
// Print the list of errors
log.Errorln("Failure in staging files:", result)
if client.ShouldRetry(result) {
log.Errorln("Errors are retryable")
os.Exit(11)
}
os.Exit(1)
}
// If we are a condor hook, we need to print the classad change out. Condor will notice it and handle the rest
if isHook {
printOutput(xformSources, extraSources)
}
}

// This function performs the actual "staging" on the specified shadow origin
func doPluginStaging(sources []string, extraSources []string, mountPrefixStr, shadowOriginPrefixStr, tokenLocation string) (xformSources []string, result error) {

for _, src := range sources {
_, newSource, result := client.DoShadowIngest(context.Background(), src, mountPrefixStr, shadowOriginPrefixStr, client.WithTokenLocation(tokenLocation), client.WithAcquireToken(false))
if result != nil {
Expand All @@ -203,23 +176,85 @@ func stagePluginMain(cmd *cobra.Command, args []string) {
xformSources = append(xformSources, newSource)
}

// Exit with failure
if result != nil {
// Print the list of errors
log.Errorln("Failure in staging files:", result)
if client.ShouldRetry(result) {
log.Errorln("Errors are retryable")
os.Exit(11)
}
os.Exit(1)
return xformSources, result
}

// This function is used to print our changes out in the case we are a condor hook
func printOutput(xformSources []string, extraSources []string) {
inputsStr := strings.Join(extraSources, ", ")
if len(extraSources) > 0 && len(xformSources) > 0 {
inputsStr = inputsStr + ", " + strings.Join(xformSources, ", ")
} else if len(xformSources) > 0 {
inputsStr = strings.Join(xformSources, ", ")
}
if isHook {
inputsStr := strings.Join(extraSources, ", ")
if len(extraSources) > 0 && len(xformSources) > 0 {
inputsStr = inputsStr + ", " + strings.Join(xformSources, ", ")
} else if len(xformSources) > 0 {
inputsStr = strings.Join(xformSources, ", ")
fmt.Printf("TransferInput = \"%s\"", inputsStr)
}

// This function is utilized to validate the arguments passed in to ensure they exist and are in the correct format
func validatePrefixes(originPrefixStr string, mountPrefixStr string, shadowOriginPrefixStr string) (originPrefixUri *url.URL, err error) {
if len(originPrefixStr) == 0 {
return nil, fmt.Errorf("Origin prefix not specified; must be a URL (osdf://...)")
}

originPrefixUri, err = url.Parse(originPrefixStr)
if err != nil {
return nil, fmt.Errorf("Origin prefix must be a URL (osdf://...): %v", err)
}
if originPrefixUri.Scheme != "osdf" {
return nil, fmt.Errorf("Origin prefix scheme must be osdf://: %s", originPrefixUri.Scheme)
}

if len(mountPrefixStr) == 0 {
return nil, fmt.Errorf("Mount prefix is required; must be a local path (/mnt/foo/...)")
}
if len(shadowOriginPrefixStr) == 0 {
return nil, fmt.Errorf("Shadow origin prefix is required; must be a URL (osdf://....)")
}

return originPrefixUri, nil
}

// This function is used when we are using a condor hook and need to get our sources from the "TransferInput" classad
// We return our sources, any extra sources, an err, and the exit code (since we have a case to exit 0)
// Note: we pass in a reader for testability but the main function will always pass stdin to get the classad
func processTransferInput(reader io.Reader, mountPrefixStr string, originPrefixPath string) (sources []string, extraSources []string, err error, exitCode int) {
buffer := make([]byte, 100*1024)
bytesread, err := reader.Read(buffer)
if err != nil {
return nil, nil, fmt.Errorf("Failed to read ClassAd from stdin: %v", err), 1
}
classad, err := classads.ParseShadowClassAd(string(buffer[:bytesread]))
if err != nil {
return nil, nil, fmt.Errorf("Failed to parse ClassAd from stdin: %v", err), 1
}
inputList, err := classad.Get("TransferInput")
if err != nil || inputList == nil {
// No TransferInput, no need to transform therefore we exit(0)
return nil, nil, fmt.Errorf("No transfer input found in classad, no need to transform."), 0
}
inputListStr, ok := inputList.(string)
if !ok {
return nil, nil, fmt.Errorf("TransferInput is not a string"), 1
}
re := regexp.MustCompile(`[,\s]+`)
for _, source := range re.Split(inputListStr, -1) {
log.Debugln("Examining transfer input file", source)
if strings.HasPrefix(source, mountPrefixStr) {
sources = append(sources, source)
} else {
// Replace the osdf:// prefix with the local mount path
source_uri, err := url.Parse(source)
source_uri_scheme := strings.SplitN(source_uri.Scheme, "+", 2)[0]
if err == nil && source_uri_scheme == "osdf" {
source_path := path.Clean("/" + source_uri.Host + "/" + source_uri.Path)
if strings.HasPrefix(source_path, originPrefixPath) {
sources = append(sources, mountPrefixStr+source_path[len(originPrefixPath):])
continue
}
}
extraSources = append(extraSources, source)
}
fmt.Printf("TransferInput = \"%s\"", inputsStr)
}
log.Debugln("Sources:", sources)
return sources, extraSources, nil, 0
}

0 comments on commit d85e199

Please sign in to comment.