diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index 9fe3daa6c81..a44a19f63ab 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -1218,6 +1218,7 @@ func (a *apiServer) PatchExperiment( } activeConfig.SetResources(resources) } + newCheckpointStorage := req.Experiment.CheckpointStorage if newCheckpointStorage != nil { @@ -1231,6 +1232,23 @@ func (a *apiServer) PatchExperiment( storage.SetSaveTrialBest(int(newCheckpointStorage.SaveTrialBest)) storage.SetSaveTrialLatest(int(newCheckpointStorage.SaveTrialLatest)) activeConfig.SetCheckpointStorage(storage) + + // Only allow checkpoint storage changes if it is not specified as an invariant config. + w, err := getWorkspaceByConfig(activeConfig) + if err != nil { + return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s", + activeConfig.Workspace())) + } + + enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig]( + ctx, &w.ID, "invariant_config", "checkpoint_storage", + model.ExperimentType) + if err != nil { + return nil, fmt.Errorf("unable to fetch task config policies: %w", err) + } + if enforcedChkptConf != nil { + activeConfig.SetCheckpointStorage(*enforcedChkptConf) + } } // `patch` represents the allowed mutations that can be performed on an experiment, in JSON @@ -1470,20 +1488,16 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string fmt.Sprintf("override config must have single searcher type got '%s' instead", overrideName)) } - // Determine which workspace the experiment is in. - wkspName := activeConfig.Workspace() - if wkspName == "" { - wkspName = model.DefaultWorkspaceName - } - ctx := context.TODO() - w, err := workspace.WorkspaceByName(ctx, wkspName) + // Merge the config with the optionally specified invariant config specified by task config + // policies. + w, err := getWorkspaceByConfig(activeConfig) if err != nil { return nil, false, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace())) } - // Merge the config with the optionally specified invariant config specified by task config - // policies. - configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(ctx, + + configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs( + context.TODO(), w.ID, mergedConfig) if err != nil { return nil, false, @@ -1499,6 +1513,15 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string return bytes.([]byte), isSingle, nil } +func getWorkspaceByConfig(config expconf.ExperimentConfig) (*model.Workspace, error) { + wkspName := config.Workspace() + if wkspName == "" { + wkspName = model.DefaultWorkspaceName + } + ctx := context.TODO() + return workspace.WorkspaceByName(ctx, wkspName) +} + var errContinueHPSearchCompleted = status.Error(codes.FailedPrecondition, "experiment has been completed, cannot continue this experiment") diff --git a/master/internal/api_experiment_intg_test.go b/master/internal/api_experiment_intg_test.go index a9a9d620e91..8b07e8c3e05 100644 --- a/master/internal/api_experiment_intg_test.go +++ b/master/internal/api_experiment_intg_test.go @@ -2347,3 +2347,27 @@ func TestDeleteExperimentsFiltered(t *testing.T) { } t.Error("expected experiments to delete after 15 seconds and they did not") } + +func TestGetWorkspaceByConfig(t *testing.T) { + api, _, ctx := setupAPITest(t, nil) + resp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{ + Name: uuid.New().String(), + }) + require.NoError(t, err) + wkspName := &resp.Workspace.Name + + t.Run("no workspace name", func(t *testing.T) { + w, err := getWorkspaceByConfig(expconf.ExperimentConfig{RawWorkspace: ptrs.Ptr("")}) + require.NoError(t, err) + + // Verify we get the Uncategorized workspace. + require.Equal(t, 1, w.ID) + }) + t.Run("has workspace name", func(t *testing.T) { + w, err := getWorkspaceByConfig(expconf.ExperimentConfig{ + RawWorkspace: wkspName, + }) + require.NoError(t, err) + require.Equal(t, *wkspName, w.Name) + }) +} diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index dea01bf7e42..988fc3ccb74 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -3,6 +3,7 @@ package configpolicy import ( "context" "database/sql" + "encoding/json" "fmt" "strings" @@ -14,8 +15,9 @@ import ( ) const ( - wkspIDQuery = "workspace_id = ?" - wkspIDGlobalQuery = "workspace_id IS ?" + wkspIDQuery = "workspace_id = ?" + wkspIDGlobalQuery = "workspace_id IS ?" + invalidPolicyTypeErr = "invalid policy type" // DefaultInvariantConfigStr is the default invariant config val used for tests. DefaultInvariantConfigStr = `{ "description": "random description", @@ -113,3 +115,59 @@ func DeleteConfigPolicies(ctx context.Context, } return nil } + +// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order +// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found. +// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string +// exactly as you wish for it to be accessed in the database. For example, if you want to access +// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT +// "resources -> max_slots". +// **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of +// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so +// that when its pointer is returned, you get an object of type *int. +func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T, + error, +) { + if policyType != "invariant_config" && policyType != "constraints" { + return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType) + } + + var confBytes []byte + var conf T + err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var globalBytes []byte + err := tx.NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). + Where("workspace_id IS NULL"). + Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes) + if err == nil && len(globalBytes) > 0 { + confBytes = globalBytes + } + if err != nil && err != sql.ErrNoRows { + return err + } + + var wkspBytes []byte + err = tx.NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). + Where("workspace_id = ?", wkspID). + Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes) + if err == nil && len(globalBytes) == 0 { + confBytes = wkspBytes + } + return err + }) + if err == sql.ErrNoRows || len(confBytes) == 0 { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("error getting config field %s: %w", field, err) + } + + err = json.Unmarshal(confBytes, &conf) + if err != nil { + return nil, fmt.Errorf("error unmarshaling config field: %w", err) + } + + return &conf, nil +} diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 9035e6cd7c2..9a35dfee6e8 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -17,6 +17,7 @@ import ( "github.com/determined-ai/determined/master/pkg/etc" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" + "github.com/determined-ai/determined/master/pkg/schemas/expconf" "github.com/stretchr/testify/require" ) @@ -659,3 +660,223 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo require.Equal(t, expJSONMap, actJSONMap) } } + +func TestGetEnforcedConfig(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, _ := db.MustResolveNewPostgresDatabase(t) + // defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + user := db.RequireMockUser(t, pgDB) + + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + + globalConf := ` +{ + "checkpoint_storage": { + "type": "shared_fs", + "host_path": "global_host_path", + "container_path": "global_container_path" + } +} +` + wkspConf := ` +{ + "checkpoint_storage": { + "type": "shared_fs", + "host_path": "wksp_host_path", + "container_path": "wksp_container_path", + "checkpoint_path": "wksp_checkpoint_path" + } +} +` + + t.Run("checkpoint storage config", func(t *testing.T) { + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: &globalConf, + }) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: &wkspConf, + }) + require.NoError(t, err) + + checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, + "invariant_config", "'checkpoint_storage'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, checkpointStorage) + + // global config enforced? + require.Equal(t, expconf.CheckpointStorageConfigV0{ + RawSharedFSConfig: &expconf.SharedFSConfigV0{ + RawHostPath: ptrs.Ptr("global_host_path"), + RawContainerPath: ptrs.Ptr("global_container_path"), + }, + }, *checkpointStorage) + }) + + globalConf = `{ + "debug": true +}` + wkspConf = ` + { + "resources": { + "max_slots": 15 + } + } +` + + t.Run("max slots config", func(t *testing.T) { + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: &globalConf, + }) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: &wkspConf, + }) + require.NoError(t, err) + + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", + "'resources' -> 'max_slots'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, maxSlots) + + // workspace config enforced? + require.Equal(t, 15, *maxSlots) + }) + + globalConstraints := ` + { + "resources": { + "max_slots": 25 + } + } +` + + wkspConstraints := ` + { + "resources": { + "max_slots": 20 + } + } +` + + t.Run("max slots constraints", func(t *testing.T) { + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &globalConstraints, + }) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &wkspConstraints, + }) + require.NoError(t, err) + + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'resources' -> 'max_slots'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, maxSlots) + + // global constraint enforced? + require.Equal(t, 25, *maxSlots) + }) + + globalConstraints = ` + { + "priority_limit": 40 + } +` + + wkspConstraints = ` + { + "priority_limit": 50 + } +` + + t.Run("priority constraints", func(t *testing.T) { + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &globalConstraints, + }) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &wkspConstraints, + }) + require.NoError(t, err) + + priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'priority_limit'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, priority) + + // global constraint enforced? + require.Equal(t, 40, *priority) + }) + + t.Run("priority constraints wksp", func(t *testing.T) { + // delete global config policies + err = DeleteConfigPolicies(ctx, nil, model.ExperimentType) + require.NoError(t, err) + + err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: &wkspConstraints, + }) + require.NoError(t, err) + + priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'priority_limit'", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, priority) + + // workspace constraint enforced? + require.Equal(t, 50, *priority) + }) + + t.Run("field not set in config", func(t *testing.T) { + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", + "'max_restarts'", model.ExperimentType) + require.NoError(t, err) + require.Nil(t, maxRestarts) + }) + + t.Run("nonexistent constraints field", func(t *testing.T) { + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", + "'max_restarts'", model.ExperimentType) + require.NoError(t, err) + require.Nil(t, maxRestarts) + }) + + t.Run("invalid policy type", func(t *testing.T) { + _, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy", + "'debug'", model.ExperimentType) + require.ErrorContains(t, err, invalidPolicyTypeErr) + }) +} diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 0b503e84e31..7f08edd51ae 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -2,6 +2,7 @@ package configpolicy import ( "bytes" + "context" "encoding/json" "fmt" "reflect" @@ -27,6 +28,9 @@ const ( InvalidNTSCConfigPolicyErr = "invalid ntsc config policy" // NotSupportedConfigPolicyErr is the error reported when admins attempt to set NTSC invariant config. NotSupportedConfigPolicyErr = "not supported" + // SlotsReqTooHighErr is the error reported when the requested slots violates the max slots + // constraint. + SlotsReqTooHighErr = "requested slots is violates max slots constraint" ) // ConfigPolicyWarning logs a warning for the configuration policy component. @@ -298,3 +302,40 @@ func configPolicyOverlap(config1, config2 interface{}) { } } } + +// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the +// enforced max slots for the workspace if that's set as an invariant config, and returns the +// requested max slots otherwise. Returns an error when max slots is not set as an invariant config +// and the requested max slots violates the constriant. +func CanSetMaxSlots(slotsReq *int, wkspID int) (*int, error) { + if slotsReq == nil { + return slotsReq, nil + } + enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID, + "invariant_config", + "'resources' -> 'max_slots'", model.ExperimentType) + if err != nil { + return nil, err + } + + if enforcedMaxSlots != nil { + return enforcedMaxSlots, nil + } + + maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID, + "constraints", + "'resources' -> 'max_slots'", model.ExperimentType) + if err != nil { + return nil, err + } + + var canSetReqSlots bool + if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit { + canSetReqSlots = true + } + if !canSetReqSlots { + return nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) + } + + return slotsReq, nil +} diff --git a/master/internal/configpolicy/utils_test.go b/master/internal/configpolicy/utils_test.go index 5ade64494ed..dd5f80820d2 100644 --- a/master/internal/configpolicy/utils_test.go +++ b/master/internal/configpolicy/utils_test.go @@ -1,11 +1,14 @@ package configpolicy import ( + "context" "testing" "github.com/stretchr/testify/require" "gotest.tools/assert" + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas/expconf" @@ -610,3 +613,95 @@ invariant_config: require.Error(t, err) // stub CM-493 }) } + +func TestCanSetMaxSlots(t *testing.T) { + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + user := db.RequireMockUser(t, pgDB) + ctx := context.Background() + w := createWorkspaceWithUser(ctx, t, user.ID) + t.Run("nil slots request", func(t *testing.T) { + slots, err := CanSetMaxSlots(nil, w.ID) + require.NoError(t, err) + require.Nil(t, slots) + }) + + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + InvariantConfig: ptrs.Ptr(` +{ + "resources": { + "max_slots": 13 + } +} +`), + Constraints: ptrs.Ptr(` +{ + "resources": { + "max_slots": 13 + } +} +`), + }) + require.NoError(t, err) + + t.Run("slots different than config higher", func(t *testing.T) { + slots, err := CanSetMaxSlots(ptrs.Ptr(15), w.ID) + require.NoError(t, err) + require.NotNil(t, slots) + require.Equal(t, 13, *slots) + }) + + t.Run("slots different than config lower", func(t *testing.T) { + slots, err := CanSetMaxSlots(ptrs.Ptr(10), w.ID) + require.NoError(t, err) + require.NotNil(t, slots) + require.Equal(t, 13, *slots) + }) + + t.Run("just constraints slots higher", func(t *testing.T) { + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: ptrs.Ptr(` + { + "resources": { + "max_slots": 23 + } + } + `), + }) + require.NoError(t, err) + + slots, err := CanSetMaxSlots(ptrs.Ptr(25), w.ID) + require.ErrorContains(t, err, SlotsReqTooHighErr) + require.Nil(t, slots) + }) + + t.Run("just constraints slots lower", func(t *testing.T) { + err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ + WorkspaceID: &w.ID, + WorkloadType: model.ExperimentType, + LastUpdatedBy: user.ID, + Constraints: ptrs.Ptr(` + { + "resources": { + "max_slots": 23 + } + } + `), + }) + require.NoError(t, err) + + slots, err := CanSetMaxSlots(ptrs.Ptr(20), w.ID) + require.NoError(t, err) + require.NotNil(t, slots) + require.Equal(t, 20, *slots) + }) +} diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 66f416bb46c..01f1c955176 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -10,7 +10,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" "github.com/shopspring/decimal" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "github.com/uptrace/bun" "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" @@ -65,7 +65,7 @@ type ( activeConfig expconf.ExperimentConfig db *internaldb.PgDB rm rm.ResourceManager - syslog *logrus.Entry + syslog *log.Entry searcher *searcher.Searcher warmStartCheckpoint *model.Checkpoint continueTrials bool @@ -168,7 +168,7 @@ func newExperiment( activeConfig: activeConfig, db: m.db, rm: m.rm, - syslog: logrus.WithFields(logrus.Fields{ + syslog: log.WithFields(log.Fields{ "component": "experiment", "job-id": expModel.JobID, "experiment-id": expModel.ID, @@ -412,6 +412,19 @@ func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) { e.mu.Lock() defer e.mu.Unlock() + w, err := getWorkspaceByConfig(e.activeConfig) + if err != nil { + log.Warnf("unable to set max slots") + return + } + + slots, err := configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + if err != nil { + log.Warnf("unable to set max slots: %s", err.Error()) + return + } + + msg.MaxSlots = slots resources := e.activeConfig.Resources() resources.SetMaxSlots(msg.MaxSlots) e.activeConfig.SetResources(resources) @@ -1100,6 +1113,21 @@ func (e *internalExperiment) setPriority(priority *int, forward bool) (err error } func (e *internalExperiment) setWeight(weight float64) error { + // Only set requested weight if it is not set in an invariant config. + w, err := getWorkspaceByConfig(e.activeConfig) + if err != nil { + return fmt.Errorf("error getting workspace: %w", err) + } + enforcedWeight, err := configpolicy.GetConfigPolicyField[float64](context.TODO(), &w.ID, + "invariant_config", + "'resources' -> 'weight'", model.ExperimentType) + if err != nil { + return fmt.Errorf("error checking against config policies: %w", err) + } + if enforcedWeight != nil { + weight = *enforcedWeight + } + resources := e.activeConfig.Resources() oldWeight := resources.Weight() resources.SetWeight(weight)