diff --git a/cmd/plugin_stage.go b/cmd/plugin_stage.go index 22cc2b24a..88668afe2 100644 --- a/cmd/plugin_stage.go +++ b/cmd/plugin_stage.go @@ -21,6 +21,7 @@ package main import ( "context" "fmt" + "io" "net/url" "os" "path" @@ -89,34 +90,18 @@ func stagePluginMain(cmd *cobra.Command, args []string) { ctx := cmd.Context() originPrefixStr := param.StagePlugin_OriginPrefix.GetString() - if len(originPrefixStr) == 0 { - log.Errorln("Origin prefix not specified; must be a URL (osdf://...)") - os.Exit(1) - } - originPrefixUri, err := url.Parse(originPrefixStr) + mountPrefixStr := param.StagePlugin_MountPrefix.GetString() + shadowOriginPrefixStr := param.StagePlugin_ShadowOriginPrefix.GetString() + + originPrefixUri, err := validatePrefixes(originPrefixStr, mountPrefixStr, shadowOriginPrefixStr) if err != nil { - log.Errorln("Origin prefix must be a URL (osdf://...):", err) - os.Exit(1) - } - if originPrefixUri.Scheme != "osdf" { - log.Errorln("Origin prefix scheme must be osdf://:", originPrefixUri.Scheme) + log.Errorln("Problem validating provided prefixes:", err) os.Exit(1) } + originPrefixPath := path.Clean("/" + originPrefixUri.Host + "/" + originPrefixUri.Path) log.Debugln("Local origin prefix:", originPrefixPath) - mountPrefixStr := param.StagePlugin_MountPrefix.GetString() - if len(mountPrefixStr) == 0 { - log.Errorln("Mount prefix is required; must be a local path (/mnt/foo/...)") - os.Exit(1) - } - - shadowOriginPrefixStr := param.StagePlugin_ShadowOriginPrefix.GetString() - if len(shadowOriginPrefixStr) == 0 { - log.Errorln("Shadow origin prefix is required; must be a URL (osdf://....)") - os.Exit(1) - } - tokenLocation := param.Plugin_Token.GetString() pb := newProgressBar() @@ -128,52 +113,12 @@ func stagePluginMain(cmd *cobra.Command, args []string) { pb.launchDisplay(ctx) } - var sources []string - var extraSources []string isHook := param.StagePlugin_Hook.GetBool() - if isHook { - buffer := make([]byte, 100*1024) - bytesread, err := os.Stdin.Read(buffer) - if err != nil { - log.Errorln("Failed to read ClassAd from stdin:", err) - os.Exit(1) - } - classad, err := classads.ParseShadowClassAd(string(buffer[:bytesread])) - if err != nil { - log.Errorln("Failed to parse ClassAd from stdin: ", err) - os.Exit(1) - } - inputList, err := classad.Get("TransferInput") - if err != nil || inputList == nil { - // No TransferInput, no need to transform... - log.Debugln("No transfer input found in classad, no need to transform.") - os.Exit(0) - } - inputListStr, ok := inputList.(string) - if !ok { - log.Errorln("TransferInput is not a string") - os.Exit(1) - } - re := regexp.MustCompile(`[,\s]+`) - for _, source := range re.Split(inputListStr, -1) { - log.Debugln("Examining transfer input file", source) - if strings.HasPrefix(source, mountPrefixStr) { - sources = append(sources, source) - } else { - // Replace the osdf:// prefix with the local mount path - source_uri, err := url.Parse(source) - source_uri_scheme := strings.SplitN(source_uri.Scheme, "+", 2)[0] - if err == nil && source_uri_scheme == "osdf" { - source_path := path.Clean("/" + source_uri.Host + "/" + source_uri.Path) - if strings.HasPrefix(source_path, originPrefixPath) { - sources = append(sources, mountPrefixStr+source_path[len(originPrefixPath):]) - continue - } - } - extraSources = append(extraSources, source) - } - } - } else { + var sources, extraSources []string + var exitCode int + + // If not a condor hook, our souces come from our args + if !isHook { log.Debugln("Len of source:", len(args)) if len(args) < 1 { log.Errorln("No ingest sources") @@ -183,11 +128,39 @@ func stagePluginMain(cmd *cobra.Command, args []string) { os.Exit(1) } sources = args + log.Debugln("Sources:", sources) + } else { // Otherwise, parse the classad for our sources + // We pass in stdin here because that is how we get the classad + sources, extraSources, err, exitCode = processTransferInput(os.Stdin, mountPrefixStr, originPrefixPath) + if err != nil { + log.Errorln("Failure to get sources from job's classad:", err) + os.Exit(exitCode) + } } - log.Debugln("Sources:", sources) var result error var xformSources []string + + xformSources, result = doPluginStaging(sources, extraSources, mountPrefixStr, shadowOriginPrefixStr, tokenLocation) + // Exit with failure + if result != nil { + // Print the list of errors + log.Errorln("Failure in staging files:", result) + if client.ShouldRetry(result) { + log.Errorln("Errors are retryable") + os.Exit(11) + } + os.Exit(1) + } + // If we are a condor hook, we need to print the classad change out. Condor will notice it and handle the rest + if isHook { + printOutput(xformSources, extraSources) + } +} + +// This function performs the actual "staging" on the specified shadow origin +func doPluginStaging(sources []string, extraSources []string, mountPrefixStr, shadowOriginPrefixStr, tokenLocation string) (xformSources []string, result error) { + for _, src := range sources { _, newSource, result := client.DoShadowIngest(context.Background(), src, mountPrefixStr, shadowOriginPrefixStr, client.WithTokenLocation(tokenLocation), client.WithAcquireToken(false)) if result != nil { @@ -203,23 +176,85 @@ func stagePluginMain(cmd *cobra.Command, args []string) { xformSources = append(xformSources, newSource) } - // Exit with failure - if result != nil { - // Print the list of errors - log.Errorln("Failure in staging files:", result) - if client.ShouldRetry(result) { - log.Errorln("Errors are retryable") - os.Exit(11) - } - os.Exit(1) + return xformSources, result +} + +// This function is used to print our changes out in the case we are a condor hook +func printOutput(xformSources []string, extraSources []string) { + inputsStr := strings.Join(extraSources, ", ") + if len(extraSources) > 0 && len(xformSources) > 0 { + inputsStr = inputsStr + ", " + strings.Join(xformSources, ", ") + } else if len(xformSources) > 0 { + inputsStr = strings.Join(xformSources, ", ") } - if isHook { - inputsStr := strings.Join(extraSources, ", ") - if len(extraSources) > 0 && len(xformSources) > 0 { - inputsStr = inputsStr + ", " + strings.Join(xformSources, ", ") - } else if len(xformSources) > 0 { - inputsStr = strings.Join(xformSources, ", ") + fmt.Printf("TransferInput = \"%s\"", inputsStr) +} + +// This function is utilized to validate the arguments passed in to ensure they exist and are in the correct format +func validatePrefixes(originPrefixStr string, mountPrefixStr string, shadowOriginPrefixStr string) (originPrefixUri *url.URL, err error) { + if len(originPrefixStr) == 0 { + return nil, fmt.Errorf("Origin prefix not specified; must be a URL (osdf://...)") + } + + originPrefixUri, err = url.Parse(originPrefixStr) + if err != nil { + return nil, fmt.Errorf("Origin prefix must be a URL (osdf://...): %v", err) + } + if originPrefixUri.Scheme != "osdf" { + return nil, fmt.Errorf("Origin prefix scheme must be osdf://: %s", originPrefixUri.Scheme) + } + + if len(mountPrefixStr) == 0 { + return nil, fmt.Errorf("Mount prefix is required; must be a local path (/mnt/foo/...)") + } + if len(shadowOriginPrefixStr) == 0 { + return nil, fmt.Errorf("Shadow origin prefix is required; must be a URL (osdf://....)") + } + + return originPrefixUri, nil +} + +// This function is used when we are using a condor hook and need to get our sources from the "TransferInput" classad +// We return our sources, any extra sources, an err, and the exit code (since we have a case to exit 0) +// Note: we pass in a reader for testability but the main function will always pass stdin to get the classad +func processTransferInput(reader io.Reader, mountPrefixStr string, originPrefixPath string) (sources []string, extraSources []string, err error, exitCode int) { + buffer := make([]byte, 100*1024) + bytesread, err := reader.Read(buffer) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read ClassAd from stdin: %v", err), 1 + } + classad, err := classads.ParseShadowClassAd(string(buffer[:bytesread])) + if err != nil { + return nil, nil, fmt.Errorf("Failed to parse ClassAd from stdin: %v", err), 1 + } + inputList, err := classad.Get("TransferInput") + if err != nil || inputList == nil { + // No TransferInput, no need to transform therefore we exit(0) + return nil, nil, fmt.Errorf("No transfer input found in classad, no need to transform."), 0 + } + inputListStr, ok := inputList.(string) + if !ok { + return nil, nil, fmt.Errorf("TransferInput is not a string"), 1 + } + re := regexp.MustCompile(`[,\s]+`) + for _, source := range re.Split(inputListStr, -1) { + log.Debugln("Examining transfer input file", source) + if strings.HasPrefix(source, mountPrefixStr) { + sources = append(sources, source) + } else { + // Replace the osdf:// prefix with the local mount path + source_uri, err := url.Parse(source) + source_uri_scheme := strings.SplitN(source_uri.Scheme, "+", 2)[0] + if err == nil && source_uri_scheme == "osdf" { + source_path := path.Clean("/" + source_uri.Host + "/" + source_uri.Path) + if strings.HasPrefix(source_path, originPrefixPath) { + sources = append(sources, mountPrefixStr+source_path[len(originPrefixPath):]) + continue + } + } + extraSources = append(extraSources, source) } - fmt.Printf("TransferInput = \"%s\"", inputsStr) } + log.Debugln("Sources:", sources) + return sources, extraSources, nil, 0 }