From d10fa1251e565aaffc250f54b56b1391647dc7a3 Mon Sep 17 00:00:00 2001 From: pgautier404 Date: Thu, 2 Mar 2023 08:59:45 -0800 Subject: [PATCH] fix: navigable auth and config packages (#217) * chore: clean up auth and config packages * working on config package * update config to correct new name * linter-prompted cleanup * fix naming to remove 'simple' * clean up after merge conflict * stop swallowing credential provider error in shared test context * PR feedback cleanup * don't use pointless pointer * trying to prevent linting from incorrectly rejecting my changes * Revert "trying to prevent linting from incorrectly rejecting my changes" This reverts commit b99f65598eaf956b9247c6b469bf714b3c4ec379. * check for blank endpoint values from user --- auth/credential_provider.go | 78 ++++++++++++++------------ auth/credential_provider_test.go | 48 ++++++++++++++++ config/config.go | 20 +++---- config/configurations.go | 34 ++++------- momento/simple_cache_client_test.go | 2 +- momento/test_helpers/shared_context.go | 7 ++- 6 files changed, 118 insertions(+), 71 deletions(-) diff --git a/auth/credential_provider.go b/auth/credential_provider.go index 895c755f..d1df9a29 100644 --- a/auth/credential_provider.go +++ b/auth/credential_provider.go @@ -10,11 +10,6 @@ import ( "github.com/momentohq/client-sdk-go/internal/momentoerrors" ) -type ResolveRequest struct { - AuthToken string - EndpointOverride string -} - type Endpoints struct { ControlEndpoint string CacheEndpoint string @@ -24,28 +19,58 @@ type CredentialProvider interface { GetAuthToken() string GetControlEndpoint() string GetCacheEndpoint() string + WithEndpoints(endpoints Endpoints) (CredentialProvider, error) } -type DefaultCredentialProvider struct { +type defaultCredentialProvider struct { authToken string controlEndpoint string cacheEndpoint string } -func (credentialProvider DefaultCredentialProvider) GetAuthToken() string { +func (credentialProvider defaultCredentialProvider) GetAuthToken() string { return credentialProvider.authToken } -func (credentialProvider DefaultCredentialProvider) GetControlEndpoint() string { +func (credentialProvider defaultCredentialProvider) GetControlEndpoint() string { return credentialProvider.controlEndpoint } -func (credentialProvider DefaultCredentialProvider) GetCacheEndpoint() string { +func (credentialProvider defaultCredentialProvider) GetCacheEndpoint() string { return credentialProvider.cacheEndpoint } -// NewEnvMomentoTokenProvider -// TODO: add overrides for endpoints +func FromEnvironmentVariable(envVar string) (CredentialProvider, error) { + credentialProvider, err := NewEnvMomentoTokenProvider(envVar) + if err != nil { + return nil, err + } + return credentialProvider, nil +} + +func FromString(authToken string) (CredentialProvider, error) { + credentialProvider, err := NewStringMomentoTokenProvider(authToken) + if err != nil { + return nil, err + } + return credentialProvider, nil +} + +// WithEndpoints overrides the cache and control endpoint URIs with those provided by the supplied Endpoints struct +// and returns a CredentialProvider with the new endpoint values. An endpoint supplied as an empty string is ignored +// and the existing value for that endpoint is retained. +func (credentialProvider defaultCredentialProvider) WithEndpoints(endpoints Endpoints) (CredentialProvider, error) { + if credentialProvider.cacheEndpoint != "" { + credentialProvider.cacheEndpoint = endpoints.CacheEndpoint + } + if credentialProvider.controlEndpoint != "" { + credentialProvider.controlEndpoint = endpoints.ControlEndpoint + } + return credentialProvider, nil +} + +// NewEnvMomentoTokenProvider constructor for a CredentialProvider using an environment variable to store an +// authentication token func NewEnvMomentoTokenProvider(envVariableName string) (CredentialProvider, error) { var authToken = os.Getenv(envVariableName) if authToken == "" { @@ -58,16 +83,14 @@ func NewEnvMomentoTokenProvider(envVariableName string) (CredentialProvider, err return NewStringMomentoTokenProvider(authToken) } -// NewStringMomentoTokenProvider -// TODO: add overrides for endpoints +// NewStringMomentoTokenProvider constructor for a CredentialProvider using a string containing an +// authentication token func NewStringMomentoTokenProvider(authToken string) (CredentialProvider, error) { - endpoints, err := resolve(&ResolveRequest{ - AuthToken: authToken, - }) + endpoints, err := getEndpointsFromToken(authToken) if err != nil { return nil, err } - provider := DefaultCredentialProvider{ + provider := defaultCredentialProvider{ authToken: authToken, controlEndpoint: endpoints.ControlEndpoint, cacheEndpoint: endpoints.CacheEndpoint, @@ -75,23 +98,6 @@ func NewStringMomentoTokenProvider(authToken string) (CredentialProvider, error) return provider, nil } -const ( - momentoControlEndpointPrefix = "control." - momentoCacheEndpointPrefix = "cache." - controlEndpointClaimID = "cp" - cacheEndpointClaimID = "c" -) - -func resolve(request *ResolveRequest) (*Endpoints, momentoerrors.MomentoSvcErr) { - if request.EndpointOverride != "" { - return &Endpoints{ - ControlEndpoint: momentoControlEndpointPrefix + request.EndpointOverride, - CacheEndpoint: momentoCacheEndpointPrefix + request.EndpointOverride, - }, nil - } - return getEndpointsFromToken(request.AuthToken) -} - func getEndpointsFromToken(authToken string) (*Endpoints, momentoerrors.MomentoSvcErr) { token, _, err := new(jwt.Parser).ParseUnverified(authToken, jwt.MapClaims{}) if err != nil { @@ -103,8 +109,8 @@ func getEndpointsFromToken(authToken string) (*Endpoints, momentoerrors.MomentoS } if claims, ok := token.Claims.(jwt.MapClaims); ok { return &Endpoints{ - ControlEndpoint: reflect.ValueOf(claims[controlEndpointClaimID]).String(), - CacheEndpoint: reflect.ValueOf(claims[cacheEndpointClaimID]).String(), + ControlEndpoint: reflect.ValueOf(claims["cp"]).String(), + CacheEndpoint: reflect.ValueOf(claims["c"]).String(), }, nil } return nil, momentoerrors.NewMomentoSvcErr( diff --git a/auth/credential_provider_test.go b/auth/credential_provider_test.go index 01fb6823..1d22ece4 100644 --- a/auth/credential_provider_test.go +++ b/auth/credential_provider_test.go @@ -2,6 +2,8 @@ package auth_test import ( "errors" + "fmt" + "os" "github.com/momentohq/client-sdk-go/auth" "github.com/momentohq/client-sdk-go/internal/momentoerrors" @@ -21,4 +23,50 @@ var _ = Describe("CredentialProvider", func() { Expect(momentoErr.Code()).To(Equal(momentoerrors.InvalidArgumentError)) } }) + + It("returns a credential provider from an environment variable via constructor", func() { + credentialProvider, err := auth.NewEnvMomentoTokenProvider("TEST_AUTH_TOKEN") + Expect(err).To(BeNil()) + Expect(credentialProvider.GetAuthToken()).To(Equal(os.Getenv("TEST_AUTH_TOKEN"))) + }) + + It("returns a credential provider from a string via constructor", func() { + credentialProvider, err := auth.NewStringMomentoTokenProvider(os.Getenv("TEST_AUTH_TOKEN")) + Expect(err).To(BeNil()) + Expect(credentialProvider.GetAuthToken()).To(Equal(os.Getenv("TEST_AUTH_TOKEN"))) + }) + + It("returns a credential provider from an environment variable via method", func() { + credentialProvider, err := auth.FromEnvironmentVariable("TEST_AUTH_TOKEN") + Expect(err).To(BeNil()) + Expect(credentialProvider.GetAuthToken()).To(Equal(os.Getenv("TEST_AUTH_TOKEN"))) + }) + + It("returns a credential provider from a string via method", func() { + credentialProvider, err := auth.FromString(os.Getenv("TEST_AUTH_TOKEN")) + Expect(err).To(BeNil()) + Expect(credentialProvider.GetAuthToken()).To(Equal(os.Getenv("TEST_AUTH_TOKEN"))) + }) + + It("overrides endpoints", func() { + credentialProvider, err := auth.NewEnvMomentoTokenProvider("TEST_AUTH_TOKEN") + Expect(err).To(BeNil()) + controlEndpoint := credentialProvider.GetControlEndpoint() + cacheEndpoint := credentialProvider.GetCacheEndpoint() + Expect(controlEndpoint).ToNot(BeEmpty()) + Expect(cacheEndpoint).ToNot(BeEmpty()) + + controlEndpoint = fmt.Sprintf("%s-overridden", controlEndpoint) + cacheEndpoint = fmt.Sprintf("%s-overridden", cacheEndpoint) + credentialProvider, err = credentialProvider.WithEndpoints( + auth.Endpoints{ + ControlEndpoint: controlEndpoint, + CacheEndpoint: cacheEndpoint, + }, + ) + Expect(err).To(BeNil()) + Expect(credentialProvider.GetControlEndpoint()).To(Equal(controlEndpoint)) + Expect(credentialProvider.GetCacheEndpoint()).To(Equal(cacheEndpoint)) + }) + }) diff --git a/config/config.go b/config/config.go index acc191e9..28429c2b 100644 --- a/config/config.go +++ b/config/config.go @@ -30,39 +30,39 @@ type Configuration interface { WithClientTimeout(clientTimeout time.Duration) Configuration } -type SimpleCacheConfiguration struct { +type cacheConfiguration struct { loggerFactory logger.MomentoLoggerFactory transportStrategy TransportStrategy } -func (s *SimpleCacheConfiguration) GetLoggerFactory() logger.MomentoLoggerFactory { +func (s *cacheConfiguration) GetLoggerFactory() logger.MomentoLoggerFactory { return s.loggerFactory } -func (s *SimpleCacheConfiguration) GetClientSideTimeout() time.Duration { +func (s *cacheConfiguration) GetClientSideTimeout() time.Duration { return s.transportStrategy.GetClientSideTimeout() } -func NewSimpleCacheConfiguration(props *ConfigurationProps) Configuration { - return &SimpleCacheConfiguration{ +func NewCacheConfiguration(props *ConfigurationProps) Configuration { + return &cacheConfiguration{ loggerFactory: props.LoggerFactory, transportStrategy: props.TransportStrategy, } } -func (s *SimpleCacheConfiguration) GetTransportStrategy() TransportStrategy { +func (s *cacheConfiguration) GetTransportStrategy() TransportStrategy { return s.transportStrategy } -func (s *SimpleCacheConfiguration) WithTransportStrategy(transportStrategy TransportStrategy) Configuration { - return &SimpleCacheConfiguration{ +func (s *cacheConfiguration) WithTransportStrategy(transportStrategy TransportStrategy) Configuration { + return &cacheConfiguration{ loggerFactory: s.loggerFactory, transportStrategy: transportStrategy, } } -func (s *SimpleCacheConfiguration) WithClientTimeout(clientTimeout time.Duration) Configuration { - return &SimpleCacheConfiguration{ +func (s *cacheConfiguration) WithClientTimeout(clientTimeout time.Duration) Configuration { + return &cacheConfiguration{ loggerFactory: s.loggerFactory, transportStrategy: s.transportStrategy.WithClientTimeout(clientTimeout), } diff --git a/config/configurations.go b/config/configurations.go index 4a33a4ba..6fead14e 100644 --- a/config/configurations.go +++ b/config/configurations.go @@ -6,44 +6,34 @@ import ( "github.com/momentohq/client-sdk-go/config/logger" ) -type Laptop struct { - Configuration -} - -func LatestLaptopConfig(loggerFactory ...logger.MomentoLoggerFactory) *Laptop { +func LaptopLatest(loggerFactory ...logger.MomentoLoggerFactory) Configuration { defaultLoggerFactory := logger.NewNoopMomentoLoggerFactory() if len(loggerFactory) != 0 { defaultLoggerFactory = loggerFactory[0] } - return &Laptop{ - Configuration: NewSimpleCacheConfiguration(&ConfigurationProps{ - LoggerFactory: defaultLoggerFactory, - TransportStrategy: NewStaticTransportStrategy(&TransportStrategyProps{ - GrpcConfiguration: NewStaticGrpcConfiguration(&GrpcConfigurationProps{ - deadline: 5 * time.Second, - }), + return NewCacheConfiguration(&ConfigurationProps{ + LoggerFactory: defaultLoggerFactory, + TransportStrategy: NewStaticTransportStrategy(&TransportStrategyProps{ + GrpcConfiguration: NewStaticGrpcConfiguration(&GrpcConfigurationProps{ + deadline: 5 * time.Second, }), }), - } -} - -type InRegion struct { - Configuration + }) } -func LatestInRegionConfig(loggerFactory ...logger.MomentoLoggerFactory) *InRegion { +func InRegionLatest(loggerFactory ...logger.MomentoLoggerFactory) Configuration { defaultLoggerFactory := logger.NewNoopMomentoLoggerFactory() if len(loggerFactory) != 0 { defaultLoggerFactory = loggerFactory[0] } - return &InRegion{ - Configuration: NewSimpleCacheConfiguration(&ConfigurationProps{ + return NewCacheConfiguration( + &ConfigurationProps{ LoggerFactory: defaultLoggerFactory, TransportStrategy: NewStaticTransportStrategy(&TransportStrategyProps{ GrpcConfiguration: NewStaticGrpcConfiguration(&GrpcConfigurationProps{ deadline: 1100 * time.Millisecond, }), }), - }), - } + }, + ) } diff --git a/momento/simple_cache_client_test.go b/momento/simple_cache_client_test.go index d8516cc7..08179141 100644 --- a/momento/simple_cache_client_test.go +++ b/momento/simple_cache_client_test.go @@ -33,7 +33,7 @@ var _ = Describe("CacheClient", func() { It(`errors on invalid timeout`, func() { badRequestTimeout := 0 * time.Second - sharedContext.Configuration = config.LatestLaptopConfig().WithClientTimeout(badRequestTimeout) + sharedContext.Configuration = config.LaptopLatest().WithClientTimeout(badRequestTimeout) Expect( NewCacheClient(sharedContext.Configuration, sharedContext.CredentialProvider, sharedContext.DefaultTtl), ).Error().To(HaveMomentoErrorCode(InvalidArgumentError)) diff --git a/momento/test_helpers/shared_context.go b/momento/test_helpers/shared_context.go index c05ed52b..037fda8e 100644 --- a/momento/test_helpers/shared_context.go +++ b/momento/test_helpers/shared_context.go @@ -25,9 +25,12 @@ func NewSharedContext() SharedContext { shared := SharedContext{} shared.Ctx = context.Background() - credentialProvider, _ := auth.NewEnvMomentoTokenProvider("TEST_AUTH_TOKEN") + credentialProvider, err := auth.NewEnvMomentoTokenProvider("TEST_AUTH_TOKEN") + if err != nil { + panic(err) + } shared.CredentialProvider = credentialProvider - shared.Configuration = config.LatestLaptopConfig() + shared.Configuration = config.LaptopLatest() shared.DefaultTtl = 3 * time.Second client, err := momento.NewCacheClient(shared.Configuration, shared.CredentialProvider, shared.DefaultTtl)