diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index 9efc145133..bab448cb05 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -40,6 +40,8 @@ func reinit() { if err != nil { log.Fatalf("failed to initialize settings: %v", err) } + } else { + DefaultGetter = nil } } diff --git a/pkg/settings/cresettings/settings_test.go b/pkg/settings/cresettings/settings_test.go index 95fedb083b..74fa55a89d 100644 --- a/pkg/settings/cresettings/settings_test.go +++ b/pkg/settings/cresettings/settings_test.go @@ -17,6 +17,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/contexts" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/settings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" ) @@ -199,6 +200,10 @@ func TestDefaultGetter_SettingMap(t *testing.T) { } }`) reinit() // set default vars + + // ensure merged values; defaults must remain + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + // confirm got, err = limit.GetOrDefault(ctx, DefaultGetter) require.NoError(t, err) require.False(t, got) @@ -233,13 +238,73 @@ func TestDefaultGetter_SettingMap(t *testing.T) { require.True(t, got) } -func TestChainAllows(t *testing.T) { - gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t)}, Default.PerWorkflow.ChainAllowed) +func TestDefaultEnvVars(t *testing.T) { + // confirm defaults + require.Equal(t, "", Default.PerWorkflow.ChainAllowed.Values["1234"]) + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + + t.Cleanup(reinit) // restore after + + // update defaults + t.Setenv(envNameSettingsDefault, `{ + "PerWorkflow": { + "ChainAllowed": { + "Values": { + "1234": "true" + } + } + } +}`) + reinit() // set default vars + + // confirm through Default + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["1234"]) + // without affecting others (they must merge) + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + + // confirm through DefaultGetter + gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed) require.NoError(t, err) - ctx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "foo"}) + ctx := contexts.WithCRE(t.Context(), contexts.CRE{Org: "foo", Owner: "owner-id", Workflow: "foo"}) + // defaults and global override allowed + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) + + // update overrides + t.Setenv(envNameSettingsDefault, "{}") + t.Setenv(envNameSettings, `{ + "global": { + "PerWorkflow": { + "ChainAllowed": { + "Values": { + "1234": "true" + } + } + } + } +}`) + + reinit() // set default vars + + // confirm through DefaultGetter + gl, err = limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed) + require.NoError(t, err) + + // defaults and global override allowed + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) + + // confirm through an empty, but non-nil getter + getter, err := settings.NewJSONGetter([]byte(`{}`)) + require.NoError(t, err) + gl, err = limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: getter}, Default.PerWorkflow.ChainAllowed) + require.NoError(t, err) + // defaults and global override allowed assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) - assert.ErrorIs(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)), limits.ErrorNotAllowed{}) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) } diff --git a/pkg/settings/map.go b/pkg/settings/map.go index 786cadfe85..c758d9c169 100644 --- a/pkg/settings/map.go +++ b/pkg/settings/map.go @@ -45,16 +45,19 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er if err != nil { return s.Default.DefaultValue, fmt.Errorf("failed to get value from context: %w", err) } - if g == nil { + valueOrDefault := func() (T, error) { if str, ok := s.Values[strconv.FormatUint(k, 10)]; ok { value, err = s.Default.Parse(str) if err != nil { return s.Default.DefaultValue, err } - return + return value, nil } return s.Default.DefaultValue, nil } + if g == nil { + return valueOrDefault() + } valueKey := s.Default.Key + ".Values." + strconv.FormatUint(k, 10) defaultKey := s.Default.Key + ".Default" @@ -66,7 +69,7 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er } else if str != "" { value, err = s.Default.Parse(str) if err != nil { - return s.Default.DefaultValue, err + return valueOrDefault() } return } @@ -74,12 +77,12 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er // Default override str, err = g.GetScoped(ctx, s.Default.Scope, defaultKey) if err != nil || str == "" { - return s.Default.DefaultValue, err + return valueOrDefault() } value, err = s.Default.Parse(str) if err != nil { - return s.Default.DefaultValue, err + return valueOrDefault() } return }