diff --git a/client/acquire_token.go b/client/acquire_token.go index 88756b71a..27eea9ebf 100644 --- a/client/acquire_token.go +++ b/client/acquire_token.go @@ -34,7 +34,7 @@ import ( oauth2_upstream "golang.org/x/oauth2" ) -func TokenIsAcceptable(jwtSerialized string, osdfPath string, namespace namespaces.Namespace, isWrite bool) bool { +func TokenIsAcceptable(jwtSerialized string, osdfPath string, namespace namespaces.Namespace, opts config.TokenGenerationOpts) bool { parser := jwt.Parser{SkipClaimsValidation: true} token, _, err := parser.ParseUnverified(jwtSerialized, &jwt.MapClaims{}) if err != nil { @@ -71,7 +71,7 @@ func TokenIsAcceptable(jwtSerialized string, osdfPath string, namespace namespac for _, scope := range strings.Split(scopes, " ") { scope_info := strings.Split(scope, ":") scopeOK := false - if isWrite && (scope_info[0] == "storage.modify" || scope_info[0] == "storage.create") { + if (opts.Operation == config.TokenWrite || opts.Operation == config.TokenSharedWrite) && (scope_info[0] == "storage.modify" || scope_info[0] == "storage.create") { scopeOK = true } else if scope_info[0] == "storage.read" { scopeOK = true @@ -84,7 +84,9 @@ func TokenIsAcceptable(jwtSerialized string, osdfPath string, namespace namespac acceptableScope = true break } - if strings.HasPrefix(targetResource, scope_info[1]) { + // Shared URLs must have exact matches; otherwise, prefix matching is acceptable. + if ((opts.Operation == config.TokenSharedWrite || opts.Operation == config.TokenSharedRead) && (targetResource == scope_info[1])) || + strings.HasPrefix(targetResource, scope_info[1]) { acceptableScope = true break } @@ -142,7 +144,7 @@ func RegisterClient(namespace namespaces.Namespace) (*config.PrefixEntry, error) // Given a URL and a piece of the namespace, attempt to acquire a valid // token for that URL. -func AcquireToken(destination *url.URL, namespace namespaces.Namespace, isWrite bool) (string, error) { +func AcquireToken(destination *url.URL, namespace namespaces.Namespace, opts config.TokenGenerationOpts) (string, error) { log.Debugln("Acquiring a token from configuration and OAuth2") if namespace.CredentialGen == nil || namespace.CredentialGen.Strategy == nil { @@ -208,7 +210,7 @@ func AcquireToken(destination *url.URL, namespace namespaces.Namespace, isWrite var acceptableToken *config.TokenEntry = nil acceptableUnexpiredToken := "" for idx, token := range prefixEntry.Tokens { - if !TokenIsAcceptable(token.AccessToken, destination.Path, namespace, isWrite) { + if !TokenIsAcceptable(token.AccessToken, destination.Path, namespace, opts) { continue } if acceptableToken == nil { @@ -262,7 +264,7 @@ func AcquireToken(destination *url.URL, namespace namespaces.Namespace, isWrite } } - token, err := oauth2.AcquireToken(issuer, prefixEntry, namespace.CredentialGen, destination.Path, isWrite) + token, err := oauth2.AcquireToken(issuer, prefixEntry, namespace.CredentialGen, destination.Path, opts) if err != nil { return "", err } diff --git a/client/director.go b/client/director.go index 54d9d7a4f..8fac86e13 100644 --- a/client/director.go +++ b/client/director.go @@ -63,10 +63,11 @@ func HeaderParser(values string) (retMap map[string]string) { // Given the Director response, create the ordered list of caches // and store it as namespace.SortedDirectorCaches -func CreateNsFromDirectorResp(dirResp *http.Response, namespace *namespaces.Namespace) (err error) { +func CreateNsFromDirectorResp(dirResp *http.Response) (namespace namespaces.Namespace, err error) { pelicanNamespaceHdr := dirResp.Header.Values("X-Pelican-Namespace") if len(pelicanNamespaceHdr) == 0 { - return errors.New("Pelican director did not include mandatory X-Pelican-Namespace header in response") + err = errors.New("Pelican director did not include mandatory X-Pelican-Namespace header in response") + return } xPelicanNamespace := HeaderParser(pelicanNamespaceHdr[0]) namespace.Path = xPelicanNamespace["namespace"] @@ -175,6 +176,9 @@ func QueryDirector(source string, directorUrl string) (resp *http.Response, err func GetCachesFromDirectorResponse(resp *http.Response, needsToken bool) (caches []namespaces.DirectorCache, err error) { // Get the Link header linkHeader := resp.Header.Values("Link") + if len(linkHeader) == 0 { + return []namespaces.DirectorCache{}, nil + } for _, linksStr := range strings.Split(linkHeader[0], ",") { links := strings.Split(strings.ReplaceAll(linksStr, " ", ""), ";") diff --git a/client/director_test.go b/client/director_test.go index b82b4eb17..0a2fc255c 100644 --- a/client/director_test.go +++ b/client/director_test.go @@ -109,8 +109,7 @@ func TestCreateNsFromDirectorResp(t *testing.T) { } // Call the function in question - var ns namespaces.Namespace - err := CreateNsFromDirectorResp(directorResponse, &ns) + ns, err := CreateNsFromDirectorResp(directorResponse) // Test for expected outputs assert.NoError(t, err, "Error creating Namespace from Director response") diff --git a/client/main.go b/client/main.go index 952eea244..46ad834f1 100644 --- a/client/main.go +++ b/client/main.go @@ -182,7 +182,11 @@ func getToken(destination *url.URL, namespace namespaces.Namespace, isWrite bool if token_location == "" { if !ObjectClientOptions.Plugin { - value, err := AcquireToken(destination, namespace, isWrite) + opts := config.TokenGenerationOpts{Operation: config.TokenSharedRead} + if isWrite { + opts.Operation = config.TokenSharedWrite + } + value, err := AcquireToken(destination, namespace, opts) if err == nil { return value, nil } @@ -505,7 +509,7 @@ func DoStashCPSingle(sourceFile string, destination string, methods []string, re AddError(err) return 0, err } - err = CreateNsFromDirectorResp(dirResp, &ns) + ns, err = CreateNsFromDirectorResp(dirResp) if err != nil { AddError(err) return 0, err diff --git a/client/sharing_url.go b/client/sharing_url.go new file mode 100644 index 000000000..d13a3bdaa --- /dev/null +++ b/client/sharing_url.go @@ -0,0 +1,106 @@ +/*************************************************************** + * + * Copyright (C) 2023, University of Nebraska-Lincoln + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package client + +import ( + "net/url" + "strings" + + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/param" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" +) + +func getDirectorFromUrl(objectUrl *url.URL) (string, error) { + configDirectorUrl := param.Federation_DirectorUrl.GetString() + var directorUrl string + if objectUrl.Scheme == "pelican" { + if objectUrl.Host == "" { + if configDirectorUrl == "" { + return "", errors.New("Must specify (or configure) the federation hostname with the pelican://-style URLs") + } + directorUrl = configDirectorUrl + } else { + discoveryUrl := url.URL{ + Scheme: "https", + Host: objectUrl.Host, + } + viper.Set("Federation.DirectorUrl", "") + viper.Set("Federation.DiscoveryUrl", discoveryUrl.String()) + if err := config.DiscoverFederation(); err != nil { + return "", errors.Wrapf(err, "Failed to discover location of the director for the federation %s", objectUrl.Host) + } + if directorUrl = param.Federation_DirectorUrl.GetString(); directorUrl == "" { + return "", errors.Errorf("Director for the federation %s not discovered", objectUrl.Host) + } + } + } else if objectUrl.Scheme == "osdf" && configDirectorUrl == "" { + if objectUrl.Host != "" { + objectUrl.Path = "/" + objectUrl.Host + objectUrl.Path + objectUrl.Host = "" + } + viper.Set("Federation.DiscoveryUrl", "https://osg-htc.org") + if err := config.DiscoverFederation(); err != nil { + return "", errors.Wrap(err, "Failed to discover director for the OSDF") + } + if directorUrl = param.Federation_DirectorUrl.GetString(); directorUrl == "" { + return "", errors.Errorf("Director for the OSDF not discovered") + } + } else if objectUrl.Scheme == "" { + if configDirectorUrl == "" { + return "", errors.Errorf("Must provide a federation name for path %s (e.g., pelican://osg-htc.org/%s)", objectUrl.Path, objectUrl.Path) + } else { + directorUrl = configDirectorUrl + } + } else if objectUrl.Scheme != "osdf" { + return "", errors.Errorf("Unsupported scheme for pelican: %s://", objectUrl.Scheme) + } + return directorUrl, nil +} + +func CreateSharingUrl(objectUrl *url.URL, isWrite bool) (string, error) { + directorUrl, err := getDirectorFromUrl(objectUrl) + if err != nil { + return "", err + } + objectUrl.Path = "/" + strings.TrimPrefix(objectUrl.Path, "/") + + log.Debugln("Will query director for path", objectUrl.Path) + dirResp, err := QueryDirector(objectUrl.Path, directorUrl) + if err != nil { + log.Errorln("Error while querying the Director:", err) + return "", errors.Wrapf(err, "Error while querying the director at %s", directorUrl) + } + namespace, err := CreateNsFromDirectorResp(dirResp) + if err != nil { + return "", errors.Wrapf(err, "Unable to parse response from director at %s", directorUrl) + } + + opts := config.TokenGenerationOpts{Operation: config.TokenSharedRead} + if isWrite { + opts.Operation = config.TokenSharedWrite + } + token, err := AcquireToken(objectUrl, namespace, opts) + if err != nil { + err = errors.Wrap(err, "Failed to acquire token") + } + return token, err +} diff --git a/client/sharing_url_test.go b/client/sharing_url_test.go new file mode 100644 index 000000000..93aa1e4cf --- /dev/null +++ b/client/sharing_url_test.go @@ -0,0 +1,170 @@ +/*************************************************************** + * + * Copyright (C) 2023, University of Nebraska-Lincoln + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package client + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + + "github.com/pelicanplatform/pelican/config" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDirectorGeneration(t *testing.T) { + returnError := false + returnErrorRef := &returnError + + handler := func(w http.ResponseWriter, r *http.Request) { + discoveryConfig := `{"director_endpoint": "https://location.example.com", "namespace_registration_endpoint": "https://location.example.com/namespace", "jwks_uri": "https://location.example.com/jwks"}` + if *returnErrorRef { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(discoveryConfig)) + assert.NoError(t, err) + } + } + server := httptest.NewTLSServer(http.HandlerFunc(handler)) + defer server.Close() + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + objectUrl := url.URL{ + Scheme: "pelican", + Host: serverURL.Host, + Path: "/test/foo", + } + + // Discovery works to get URL + viper.Reset() + viper.Set("TLSSkipVerify", true) + err = config.InitClient() + require.NoError(t, err) + dUrl, err := getDirectorFromUrl(&objectUrl) + require.NoError(t, err) + assert.Equal(t, dUrl, "https://location.example.com") + + // Discovery URL overrides the federation config. + viper.Reset() + viper.Set("TLSSkipVerify", true) + viper.Set("Federation.DirectorURL", "https://location2.example.com") + dUrl, err = getDirectorFromUrl(&objectUrl) + require.NoError(t, err) + assert.Equal(t, dUrl, "https://location.example.com") + + // Fallback to configuration if no discovery present + viper.Reset() + viper.Set("Federation.DirectorURL", "https://location2.example.com") + objectUrl.Host = "" + dUrl, err = getDirectorFromUrl(&objectUrl) + require.NoError(t, err) + assert.Equal(t, dUrl, "https://location2.example.com") + + // Error if server has an error + viper.Reset() + returnError = true + viper.Set("TLSSkipVerify", true) + objectUrl.Host = serverURL.Host + _, err = getDirectorFromUrl(&objectUrl) + require.Error(t, err) + + // Error if neither config nor hostname provided. + viper.Reset() + objectUrl.Host = "" + _, err = getDirectorFromUrl(&objectUrl) + require.Error(t, err) + + // Error on unknown scheme + viper.Reset() + objectUrl.Scheme = "buzzard" + _, err = getDirectorFromUrl(&objectUrl) + require.Error(t, err) +} + +func TestSharingUrl(t *testing.T) { + // Construct a local server that we can poke with QueryDirector + myUrl := "http://redirect.com" + myUrlRef := &myUrl + log.SetLevel(log.DebugLevel) + handler := func(w http.ResponseWriter, r *http.Request) { + issuerLoc := *myUrlRef + "/issuer" + + if strings.HasPrefix(r.URL.Path, "/test") { + w.Header().Set("Location", *myUrlRef) + w.Header().Set("X-Pelican-Namespace", "namespace=/test, require-token=true") + w.Header().Set("X-Pelican-Authorization", fmt.Sprintf("issuer=%s", issuerLoc)) + w.Header().Set("X-Pelican-Token-Generation", fmt.Sprintf("issuer=%s, base-path=/test, strategy=OAuth2", issuerLoc)) + w.WriteHeader(http.StatusTemporaryRedirect) + } else if r.URL.Path == "/issuer/.well-known/openid-configuration" { + w.WriteHeader(http.StatusOK) + oidcConfig := fmt.Sprintf(`{"token_endpoint": "%s/token", "registration_endpoint": "%s/register", "grant_types_supported": ["urn:ietf:params:oauth:grant-type:device_code"], "device_authorization_endpoint": "%s/device_authz"}`, issuerLoc, issuerLoc, issuerLoc) + _, err := w.Write([]byte(oidcConfig)) + assert.NoError(t, err) + } else if r.URL.Path == "/issuer/register" { + //requestBytes, err := io.ReadAll(r.Body) + //assert.NoError(t, err) + clientConfig := `{"client_id": "client1", "client_secret": "secret", "client_secret_expires_at": 0}` + w.WriteHeader(http.StatusCreated) + _, err := w.Write([]byte(clientConfig)) + assert.NoError(t, err) + } else if r.URL.Path == "/issuer/device_authz" { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"device_code": "1234", "user_code": "5678", "interval": 1, "verification_uri": "https://example.com", "expires_in": 20}`)) + assert.NoError(t, err) + } else if r.URL.Path == "/issuer/token" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"access_token": "token1234", "token_type": "jwt"}`)) + assert.NoError(t, err) + } else { + fmt.Println(r) + requestBytes, err := io.ReadAll(r.Body) + assert.NoError(t, err) + fmt.Println(string(requestBytes)) + w.WriteHeader(http.StatusInternalServerError) + } + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + myUrl = server.URL + + os.Setenv("PELICAN_SKIP_TERMINAL_CHECK", "password") + defer os.Unsetenv("PELICAN_SKIP_TERMINAL_CHECK") + viper.Set("Federation.DirectorURL", myUrl) + viper.Set("ConfigDir", t.TempDir()) + err := config.InitClient() + assert.NoError(t, err) + + // Call QueryDirector with the test server URL and a source path + testUrl, err := url.Parse("/test/foo/bar") + require.NoError(t, err) + token, err := CreateSharingUrl(testUrl, true) + assert.NoError(t, err) + assert.NotEmpty(t, token) + fmt.Println(token) +} diff --git a/cmd/config_mgr.go b/cmd/config_mgr.go index 9b873d23e..18c3a74c7 100644 --- a/cmd/config_mgr.go +++ b/cmd/config_mgr.go @@ -148,7 +148,11 @@ func addTokenSubcommands(tokenCmd *cobra.Command) { os.Exit(1) } - token, err := client.AcquireToken(&dest, namespace, isWrite) + opts := config.TokenGenerationOpts{Operation: config.TokenRead} + if isWrite { + opts.Operation = config.TokenWrite + } + token, err := client.AcquireToken(&dest, namespace, opts) if err != nil { fmt.Fprintln(os.Stderr, "Failed to get a token:", err) os.Exit(1) diff --git a/cmd/object_share.go b/cmd/object_share.go new file mode 100644 index 000000000..6dfbc812e --- /dev/null +++ b/cmd/object_share.go @@ -0,0 +1,75 @@ +/*************************************************************** + * + * Copyright (C) 2023, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package main + +import ( + "fmt" + "net/url" + + "github.com/pelicanplatform/pelican/client" + "github.com/pelicanplatform/pelican/config" + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +var ( + shareCmd = &cobra.Command{ + Use: "share {URL}", + Short: `Generate a string for sharing access to a namespace. +Note the sharing is based on prefixes; all object names matching the prefix will be accessible`, + RunE: shareMain, + } +) + +func init() { + flagSet := shareCmd.Flags() + flagSet.Bool("write", false, "Allow writes to the target prefix") + objectCmd.AddCommand(shareCmd) +} + +func shareMain(cmd *cobra.Command, args []string) error { + + err := config.InitClient() + if err != nil { + return errors.Wrap(err, "Failed to initialize the client") + } + + isWrite, err := cmd.Flags().GetBool("write") + if err != nil { + return errors.Wrap(err, "Unable to get the value of the --write flag") + } + + if len(args) == 0 { + return errors.New("A URL must be specified to share") + } + + objectUrl, err := url.Parse(args[0]) + if err != nil { + return errors.Wrapf(err, "Failed to parse '%v' as a URL", args[0]) + } + + token, err := client.CreateSharingUrl(objectUrl, isWrite) + if err != nil { + return errors.Wrapf(err, "Failed to create a sharing URL for %v", objectUrl.String()) + } + + objectUrl.RawQuery = "authz=" + token + fmt.Println(objectUrl.String()) + return nil +} diff --git a/config/config.go b/config/config.go index 1c6384bd0..53efe2ad3 100644 --- a/config/config.go +++ b/config/config.go @@ -46,35 +46,49 @@ import ( ) // Structs holding the OAuth2 state (and any other OSDF config needed) +type ( + TokenEntry struct { + Expiration int64 `yaml:"expiration"` + AccessToken string `yaml:"access_token"` + RefreshToken string `yaml:"refresh_token,omitempty"` + } -type TokenEntry struct { - Expiration int64 `yaml:"expiration"` - AccessToken string `yaml:"access_token"` - RefreshToken string `yaml:"refresh_token,omitempty"` -} + PrefixEntry struct { + // OSDF namespace prefix + Prefix string `yaml:"prefix"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + Tokens []TokenEntry `yaml:"tokens,omitempty"` + } -type PrefixEntry struct { - // OSDF namespace prefix - Prefix string `yaml:"prefix"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - Tokens []TokenEntry `yaml:"tokens,omitempty"` -} + OSDFConfig struct { -type OSDFConfig struct { + // Top-level OSDF object + OSDF struct { + // List of OAuth2 client configurations + OauthClient []PrefixEntry `yaml:"oauth_client,omitempty"` + } `yaml:"OSDF"` + } - // Top-level OSDF object - OSDF struct { - // List of OAuth2 client configurations - OauthClient []PrefixEntry `yaml:"oauth_client,omitempty"` - } `yaml:"OSDF"` -} + FederationDiscovery struct { + DirectorEndpoint string `json:"director_endpoint"` + NamespaceRegistrationEndpoint string `json:"namespace_registration_endpoint"` + JwksUri string `json:"jwks_uri"` + } -type FederationDiscovery struct { - DirectorEndpoint string `json:"director_endpoint"` - NamespaceRegistrationEndpoint string `json:"namespace_registration_endpoint"` - JwksUri string `json:"jwks_uri"` -} + TokenOperation int + + TokenGenerationOpts struct { + Operation TokenOperation + } +) + +const ( + TokenWrite TokenOperation = iota + TokenRead + TokenSharedWrite + TokenSharedRead +) var ( // Some of the unit tests probe behavior specific to OSDF vs Pelican. Hence, @@ -167,7 +181,8 @@ func DiscoverFederation() error { discoveryUrl.Path = path.Join(".well-known/pelican-configuration", federationUrl.Path) httpClient := http.Client{ - Timeout: time.Second * 5, + Transport: GetTransport(), + Timeout: time.Second * 5, } req, err := http.NewRequest(http.MethodGet, discoveryUrl.String(), nil) if err != nil { @@ -345,9 +360,8 @@ func InitConfig() { } } -func InitServer() error { +func initConfigDir() error { configDir := viper.GetString("ConfigDir") - viper.SetConfigType("yaml") if configDir == "" { if IsRootExecution() { configDir = "/etc/pelican" @@ -360,7 +374,15 @@ func InitServer() error { } viper.SetDefault("ConfigDir", configDir) } + return nil +} +func InitServer() error { + if err := initConfigDir(); err != nil { + return errors.Wrap(err, "Failed to initialize the server configuration") + } + configDir := viper.GetString("ConfigDir") + viper.SetConfigType("yaml") viper.SetDefault("Server.TLSCertificate", filepath.Join(configDir, "certificates", "tls.crt")) viper.SetDefault("Server.TLSKey", filepath.Join(configDir, "certificates", "tls.key")) viper.SetDefault("Server.TLSCAKey", filepath.Join(configDir, "certificates", "tlsca.key")) @@ -437,18 +459,15 @@ func InitServer() error { } func InitClient() error { - if IsRootExecution() { - viper.SetDefault("IssuerKey", "/etc/pelican/issuer.jwk") - } else { - configBase, err := getConfigBase() - if err != nil { - log.Warningln("No home directory found for user -- will check for configuration yaml in /etc/pelican/") - } - viper.SetDefault("IssuerKey", filepath.Join(configBase, "issuer.jwk")) + if err := initConfigDir(); err != nil { + log.Warningln("No home directory found for user -- will check for configuration yaml in /etc/pelican/") + viper.Set("ConfigDir", "/etc/pelican") } + configDir := viper.GetString("ConfigDir") + viper.SetDefault("IssuerKey", filepath.Join(configDir, "issuer.jwk")) + upper_prefix := GetPreferredPrefix() - lower_prefix := strings.ToLower(upper_prefix) viper.SetDefault("Client.StoppedTransferTimeout", 100) viper.SetDefault("Client.SlowTransferRampupTime", 100) @@ -463,7 +482,6 @@ func InitClient() error { viper.SetConfigName("config") viper.SetConfigType("yaml") - viper.AddConfigPath("$HOME/." + lower_prefix) err := viper.ReadInConfig() if err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { diff --git a/config/encrypted.go b/config/encrypted.go index ab5089baf..4c47934de 100644 --- a/config/encrypted.go +++ b/config/encrypted.go @@ -29,6 +29,7 @@ import ( "os" "path/filepath" + "github.com/spf13/viper" "github.com/youmark/pkcs8" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" @@ -42,22 +43,20 @@ import ( var setEmptyPassword = false func GetEncryptedConfigName() (string, error) { - if IsRootExecution() { - return "/etc/pelican/credentials/client-credentials.pem", nil + configDir := viper.GetString("ConfigDir") + if GetPreferredPrefix() == "PELICAN" || IsRootExecution() { + return filepath.Join(configDir, "credentials", "client-credentials.pem"), nil } - config_location := filepath.Join("pelican", "client-credentials.pem") - if GetPreferredPrefix() != "PELICAN" { - config_location = filepath.Join("osdf-client", "oauth2-client.pem") - } - config_root := os.Getenv("XDG_CONFIG_HOME") - if len(config_root) == 0 { + configLocation := filepath.Join("osdf-client", "oauth2-client.pem") + configRoot := os.Getenv("XDG_CONFIG_HOME") + if len(configRoot) == 0 { dirname, err := os.UserHomeDir() if err != nil { return "", err } - config_root = filepath.Join(dirname, ".config") + configRoot = filepath.Join(dirname, ".config") } - return filepath.Join(config_root, config_location), nil + return filepath.Join(configRoot, configLocation), nil } func EncryptedConfigExists() (bool, error) { diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 11eef5a6e..7346e1a09 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -20,7 +20,6 @@ package oauth2 import ( "context" - "errors" "fmt" "os" "path" @@ -28,6 +27,7 @@ import ( config "github.com/pelicanplatform/pelican/config" namespaces "github.com/pelicanplatform/pelican/namespaces" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -64,9 +64,9 @@ func trimPath(pathName string, maxDepth int) string { return "/" + path.Join(pathComponents[0:maxLength]...) } -func AcquireToken(issuerUrl string, entry *config.PrefixEntry, credentialGen *namespaces.CredentialGeneration, osdfPath string, isWrite bool) (*config.TokenEntry, error) { +func AcquireToken(issuerUrl string, entry *config.PrefixEntry, credentialGen *namespaces.CredentialGeneration, osdfPath string, opts config.TokenGenerationOpts) (*config.TokenEntry, error) { - if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 { + if fileInfo, _ := os.Stdout.Stat(); (len(os.Getenv(config.GetPreferredPrefix()+"_SKIP_TERMINAL_CHECK")) == 0) && ((fileInfo.Mode() & os.ModeCharDevice) == 0) { return nil, errors.New("This program must be run in a terminal to acquire a new token") } @@ -92,13 +92,13 @@ func AcquireToken(issuerUrl string, entry *config.PrefixEntry, credentialGen *na } // Potentially increase the coarseness of the token - if credentialGen.MaxScopeDepth != nil && *credentialGen.MaxScopeDepth >= 0 { + if opts.Operation != config.TokenSharedWrite && opts.Operation != config.TokenSharedRead && credentialGen.MaxScopeDepth != nil && *credentialGen.MaxScopeDepth >= 0 { pathCleaned = trimPath(pathCleaned, *credentialGen.MaxScopeDepth) } } var storageScope string - if isWrite { + if opts.Operation == config.TokenSharedWrite || opts.Operation == config.TokenWrite { storageScope = "storage.create:" } else { storageScope = "storage.read:" @@ -118,7 +118,7 @@ func AcquireToken(issuerUrl string, entry *config.PrefixEntry, credentialGen *na ctx := context.Background() deviceAuth, err := oauth2Config.AuthDevice(ctx) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "Failed to perform device code flow with URL %s", issuerInfo.DeviceAuthURL) } if len(deviceAuth.VerificationURIComplete) > 0 {