From 70b0d5e2cc21c03ebb9a4f48882cd9520fcc35a4 Mon Sep 17 00:00:00 2001 From: Amanda Vialva Date: Fri, 25 Oct 2024 13:23:19 -0400 Subject: [PATCH] add more test and address comments --- master/internal/api_experiment.go | 9 +- .../postgres_task_config_policy.go | 61 ++++++------ .../postgres_task_config_policy_intg_test.go | 66 +++++++++++-- master/internal/configpolicy/utils.go | 41 ++++++++ master/internal/configpolicy/utils_test.go | 98 +++++++++++++++++++ master/internal/experiment.go | 45 ++++----- 6 files changed, 255 insertions(+), 65 deletions(-) diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index b9b116d0905c..a44a19f63abe 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -1219,7 +1219,6 @@ func (a *apiServer) PatchExperiment( activeConfig.SetResources(resources) } - // Only allow setting checkpoint storage if it is not specified as an invariant config. newCheckpointStorage := req.Experiment.CheckpointStorage if newCheckpointStorage != nil { @@ -1241,7 +1240,7 @@ func (a *apiServer) PatchExperiment( activeConfig.Workspace())) } - enforcedChkptConf, err := configpolicy.GetEnforcedConfig[expconf.CheckpointStorageConfig]( + enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig]( ctx, &w.ID, "invariant_config", "checkpoint_storage", model.ExperimentType) if err != nil { @@ -1262,6 +1261,7 @@ func (a *apiServer) PatchExperiment( if !ok { return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(exp.Id)), true) } + if newResources.MaxSlots != nil { msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))} e.SetGroupMaxSlots(msg) @@ -1488,15 +1488,14 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string fmt.Sprintf("override config must have single searcher type got '%s' instead", overrideName)) } - // Determine which workspace the experiment is in. 
+ // Merge the config with the optionally specified invariant config specified by task config + // policies. w, err := getWorkspaceByConfig(activeConfig) if err != nil { return nil, false, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace())) } - // Merge the config with the optionally specified invariant config specified by task config - // policies. configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs( context.TODO(), w.ID, mergedConfig) diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index e6b344868da0..988fc3ccb74e 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -15,8 +15,9 @@ import ( ) const ( - wkspIDQuery = "workspace_id = ?" - wkspIDGlobalQuery = "workspace_id IS ?" + wkspIDQuery = "workspace_id = ?" + wkspIDGlobalQuery = "workspace_id IS ?" + invalidPolicyTypeErr = "invalid policy type" // DefaultInvariantConfigStr is the default invariant config val used for tests. DefaultInvariantConfigStr = `{ "description": "random description", @@ -115,44 +116,48 @@ func DeleteConfigPolicies(ctx context.Context, return nil } -// GetEnforcedConfig gets the fields of the global invariant config or constraint if specified, and -// the workspace invariant config or constraint otherwise. If neither is specified, returns nil. -func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T, +// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order +// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found. +// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string +// exactly as you wish for it to be accessed in the database. 
For example, if you want to access +// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT +// "resources -> max_slots". +// **NOTE** When using this function to retrieve an object of Kind Pointer, set T as the Type of +// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so +// that when its pointer is returned, you get an object of type *int. +func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T, error, ) { if policyType != "invariant_config" && policyType != "constraints" { - return nil, fmt.Errorf("invalid policy type :%s", policyType) + return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType) } var confBytes []byte var conf T err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - globalField := tx.NewSelect(). - ColumnExpr("? -> ? AS globconf", bun.Safe(policyType), bun.Safe(field)). - Table("task_config_policies"). + var globalBytes []byte + err := tx.NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). Where("workspace_id IS NULL"). - Where("workload_type = ?", workloadType) - - wkspField := tx.NewSelect(). - ColumnExpr("? -> ? AS wkspconf", bun.Safe(policyType), bun.Safe(field)). - Table("task_config_policies"). - Where("workspace_id = '?'", wkspID). - Where("workload_type = ?", workloadType) - - both := tx.NewSelect().TableExpr("global_field"). - Join("NATURAL JOIN wksp_field") - - err := tx.NewSelect().ColumnExpr("coalesce(globconf, wkspconf)"). - With("global_field", globalField). - With("wksp_field", wkspField). - Table("both").With("both", both).
- Scan(ctx, &confBytes) - if err != nil { + Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes) + if err == nil && len(globalBytes) > 0 { + confBytes = globalBytes + } + if err != nil && err != sql.ErrNoRows { return err } - return nil + + var wkspBytes []byte + err = tx.NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). + Where("workspace_id = ?", wkspID). + Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes) + if err == nil && len(globalBytes) == 0 { + confBytes = wkspBytes + } + return err }) - if err == sql.ErrNoRows { + if err == sql.ErrNoRows || len(confBytes) == 0 { return nil, nil } if err != nil { diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 8b44700d8eac..9a35dfee6e8d 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -664,8 +664,8 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo func TestGetEnforcedConfig(t *testing.T) { ctx := context.Background() require.NoError(t, etc.SetRootPath(db.RootFromDB)) - pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) - defer cleanup() + pgDB, _ := db.MustResolveNewPostgresDatabase(t) + // defer cleanup() db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) user := db.RequireMockUser(t, pgDB) @@ -710,10 +710,12 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - checkpointStorage, err := GetEnforcedConfig[expconf.CheckpointStorageConfig](ctx, &w.ID, + checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, "invariant_config", "'checkpoint_storage'", model.ExperimentType) require.NoError(t, err) require.NotNil(t, checkpointStorage) + + // global config enforced? 
require.Equal(t, expconf.CheckpointStorageConfigV0{ RawSharedFSConfig: &expconf.SharedFSConfigV0{ RawHostPath: ptrs.Ptr("global_host_path"), @@ -749,10 +751,12 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "invariant_config", + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", "'resources' -> 'max_slots'", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) + + // workspace config enforced? require.Equal(t, 15, *maxSlots) }) @@ -788,10 +792,12 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints", + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", "'resources' -> 'max_slots'", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) + + // global constraint enforced? require.Equal(t, 25, *maxSlots) }) @@ -823,10 +829,54 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints", + priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", "'priority_limit'", model.ExperimentType) require.NoError(t, err) - require.NotNil(t, maxSlots) - require.Equal(t, 40, *maxSlots) + require.NotNil(t, priority) + + // global constraint enforced? 
+ require.Equal(t, 40, *priority) + }) + + t.Run("priority constraints wksp", func(t *testing.T) { + // delete global config policies + err = DeleteConfigPolicies(ctx, nil, model.ExperimentType) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &wkspConstraints, + }) + require.NoError(t, err) + + priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'priority_limit'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, priority) + + // workspace constraint enforced? + require.Equal(t, 50, *priority) + }) + + t.Run("field not set in config", func(t *testing.T) { + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", + "'max_restarts'", model.ExperimentType) + require.NoError(t, err) + require.Nil(t, maxRestarts) + }) + + t.Run("nonexistent constraints field", func(t *testing.T) { + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'max_restarts'", model.ExperimentType) + require.NoError(t, err) + require.Nil(t, maxRestarts) + }) + + t.Run("invalid policy type", func(t *testing.T) { + _, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy", + "'debug'", model.ExperimentType) + require.ErrorContains(t, err, invalidPolicyTypeErr) }) } diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 0b503e84e31e..1bcb310c29c4 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -2,6 +2,7 @@ package configpolicy import ( "bytes" + "context" "encoding/json" "fmt" "reflect" @@ -27,6 +28,9 @@ const ( InvalidNTSCConfigPolicyErr = "invalid ntsc config policy" // NotSupportedConfigPolicyErr is the error reported when admins attempt to set NTSC invariant config. 
NotSupportedConfigPolicyErr = "not supported" + // SlotsReqTooHighErr is the error reported when the requested slots violate the max slots + // constraint. + SlotsReqTooHighErr = "requested slots is violates max slots constraint" ) // ConfigPolicyWarning logs a warning for the configuration policy component. @@ -298,3 +302,40 @@ func configPolicyOverlap(config1, config2 interface{}) { } } } + +// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the +// enforced max slots for the workspace if that's set as an invariant config, and returns the +// requested max slots otherwise. Returns an error when max slots is not set as an invariant config +// and the requested max slots violates the constraint. +func CanSetMaxSlots(slotsReq *int, wkspID int) (bool, *int, error) { + if slotsReq == nil { + return true, slotsReq, nil + } + enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID, + "invariant_config", + "'resources' -> 'max_slots'", model.ExperimentType) + if err != nil { + return false, nil, err + } + + if enforcedMaxSlots != nil { + return true, enforcedMaxSlots, nil + } + + maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID, + "constraints", + "'resources' -> 'max_slots'", model.ExperimentType) + if err != nil { + return false, nil, err + } + + var canSetReqSlots bool + if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit { + canSetReqSlots = true + } + if !canSetReqSlots { + return false, nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) + } + + return true, slotsReq, nil +} diff --git a/master/internal/configpolicy/utils_test.go b/master/internal/configpolicy/utils_test.go index 5ade64494ed6..eab375eb808c 100644 --- a/master/internal/configpolicy/utils_test.go +++ b/master/internal/configpolicy/utils_test.go @@ -1,11 +1,14 @@ package configpolicy import ( + "context" "testing" "github.com/stretchr/testify/require" "gotest.tools/assert" +
"github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas/expconf" @@ -610,3 +613,98 @@ invariant_config: require.Error(t, err) // stub CM-493 }) } + +func TestCanSetMaxSlots(t *testing.T) { + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + user := db.RequireMockUser(t, pgDB) + ctx := context.Background() + w := createWorkspaceWithUser(ctx, t, user.ID) + t.Run("nil slots request", func(t *testing.T) { + canSetReqSlots, slots, err := CanSetMaxSlots(nil, w.ID) + require.NoError(t, err) + require.Nil(t, slots) + require.True(t, canSetReqSlots) + }) + + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: ptrs.Ptr(` +{ + "resources": { + "max_slots": 13 + } +} +`), + Constraints: ptrs.Ptr(` +{ + "resources": { + "max_slots": 13 + } +} +`), + }) + require.NoError(t, err) + + t.Run("slots different than config higher", func(t *testing.T) { + canSetReqSlots, slots, err := CanSetMaxSlots(ptrs.Ptr(15), w.ID) + require.NoError(t, err) + require.True(t, canSetReqSlots) + require.NotNil(t, slots) + require.Equal(t, 13, *slots) + }) + + t.Run("slots different than config lower", func(t *testing.T) { + canSetReqSlots, slots, err := CanSetMaxSlots(ptrs.Ptr(10), w.ID) + require.NoError(t, err) + require.True(t, canSetReqSlots) + require.NotNil(t, slots) + require.Equal(t, 13, *slots) + }) + + t.Run("just constarints slots higher", func(t *testing.T) { + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: 
ptrs.Ptr(` + { + "resources": { + "max_slots": 23 + } + } + `), + }) + + canSetReqSlots, slots, err := CanSetMaxSlots(ptrs.Ptr(25), w.ID) + require.ErrorContains(t, err, SlotsReqTooHighErr) + require.False(t, canSetReqSlots) + require.Nil(t, slots) + }) + + t.Run("just constarints slots lower", func(t *testing.T) { + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: ptrs.Ptr(` + { + "resources": { + "max_slots": 23 + } + } + `), + }) + + canSetReqSlots, slots, err := CanSetMaxSlots(ptrs.Ptr(20), w.ID) + require.NoError(t, err) + require.True(t, canSetReqSlots) + require.NotNil(t, slots) + require.Equal(t, 20, *slots) + }) +} diff --git a/master/internal/experiment.go b/master/internal/experiment.go index a7ed54c02424..f52d77195727 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -411,40 +411,20 @@ func (e *internalExperiment) PatchTrialState(msg experiment.PatchTrialState) err func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) { e.mu.Lock() defer e.mu.Unlock() - // Only allow max slots changes if it is not specified as an invariant config or enforced as a - // constraint. 
w, err := getWorkspaceByConfig(e.activeConfig) if err != nil { log.Warnf("unable to set max slots") return } - enforcedMaxSlots, err := configpolicy.GetEnforcedConfig[int](context.TODO(), &w.ID, - "invariant_config", - "'resources' -> 'max_slots'", model.ExperimentType) - if err != nil { - log.Warnf("unable to set max slots") - return - } - - if enforcedMaxSlots != nil { - msg.MaxSlots = enforcedMaxSlots - } - maxSlotsLimit, err := configpolicy.GetEnforcedConfig[int](context.TODO(), &w.ID, - "constraints", - "'resources' -> 'max_slots'", model.ExperimentType) - if err != nil { - log.Warnf("unable to set max slots") - return - } - - if enforcedMaxSlots == nil && maxSlotsLimit != nil && msg.MaxSlots != nil && - *msg.MaxSlots > *maxSlotsLimit { - log.Warnf("unable to set max slots") + canContinue, slots, err := configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + if !canContinue { + log.Warnf("unable to set max slots: %s", err.Error()) return } + msg.MaxSlots = slots resources := e.activeConfig.Resources() resources.SetMaxSlots(msg.MaxSlots) e.activeConfig.SetResources(resources) @@ -1133,6 +1113,23 @@ func (e *internalExperiment) setPriority(priority *int, forward bool) (err error } func (e *internalExperiment) setWeight(weight float64) error { + // Only set requested weight if it is not set in an invariant config. + w, err := getWorkspaceByConfig(e.activeConfig) + if err != nil { + log.Warnf("unable to set max slots") + return fmt.Errorf("unable to set weight, ") + } + enforcedWeight, err := configpolicy.GetConfigPolicyField[float64](context.TODO(), &w.ID, + "invariant_config", + "'resources' -> 'weight'", model.ExperimentType) + if err != nil { + log.Warnf("unable to set weight %v", weight) + return nil + } + if enforcedWeight != nil { + weight = *enforcedWeight + } + resources := e.activeConfig.Resources() oldWeight := resources.Weight() resources.SetWeight(weight)