Skip to content

Commit

Permalink
add more test and address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Oct 25, 2024
1 parent afe5e43 commit 70b0d5e
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 65 deletions.
9 changes: 4 additions & 5 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,6 @@ func (a *apiServer) PatchExperiment(
activeConfig.SetResources(resources)
}

// Only allow setting checkpoint storage if it is not specified as an invariant config.
newCheckpointStorage := req.Experiment.CheckpointStorage

if newCheckpointStorage != nil {
Expand All @@ -1241,7 +1240,7 @@ func (a *apiServer) PatchExperiment(
activeConfig.Workspace()))
}

enforcedChkptConf, err := configpolicy.GetEnforcedConfig[expconf.CheckpointStorageConfig](
enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig](
ctx, &w.ID, "invariant_config", "checkpoint_storage",
model.ExperimentType)
if err != nil {
Expand All @@ -1262,6 +1261,7 @@ func (a *apiServer) PatchExperiment(
if !ok {
return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(exp.Id)), true)
}

if newResources.MaxSlots != nil {
msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))}
e.SetGroupMaxSlots(msg)
Expand Down Expand Up @@ -1488,15 +1488,14 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
fmt.Sprintf("override config must have single searcher type got '%s' instead", overrideName))
}

// Determine which workspace the experiment is in.
// Merge the config with the optionally specified invariant config specified by task config
// policies.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, false, status.Errorf(codes.Internal,
fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace()))
}

// Merge the config with the optionally specified invariant config specified by task config
// policies.
configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(
context.TODO(),
w.ID, mergedConfig)
Expand Down
61 changes: 33 additions & 28 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import (
)

const (
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
invalidPolicyTypeErr = "invalid policy type"
// DefaultInvariantConfigStr is the default invariant config val used for tests.
DefaultInvariantConfigStr = `{
"description": "random description",
Expand Down Expand Up @@ -115,44 +116,48 @@ func DeleteConfigPolicies(ctx context.Context,
return nil
}

// GetEnforcedConfig gets the fields of the global invariant config or constraint if specified, and
// the workspace invariant config or constraint otherwise. If neither is specified, returns nil.
func GetEnforcedConfig[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order
// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found.
// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string
// exactly as you wish for it to be accessed in the database. For example, if you want to access
// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT
// "resources -> max_slots".
// **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of
// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so
// that when its pointer is returned, you get an object of type *int.
func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
error,
) {
if policyType != "invariant_config" && policyType != "constraints" {
return nil, fmt.Errorf("invalid policy type :%s", policyType)
return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType)
}

var confBytes []byte
var conf T
err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
globalField := tx.NewSelect().
ColumnExpr("? -> ? AS globconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
var globalBytes []byte
err := tx.NewSelect().Table("task_config_policies").
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id IS NULL").
Where("workload_type = ?", workloadType)

wkspField := tx.NewSelect().
ColumnExpr("? -> ? AS wkspconf", bun.Safe(policyType), bun.Safe(field)).
Table("task_config_policies").
Where("workspace_id = '?'", wkspID).
Where("workload_type = ?", workloadType)

both := tx.NewSelect().TableExpr("global_field").
Join("NATURAL JOIN wksp_field")

err := tx.NewSelect().ColumnExpr("coalesce(globconf, wkspconf)").
With("global_field", globalField).
With("wksp_field", wkspField).
Table("both").With("both", both).
Scan(ctx, &confBytes)
if err != nil {
Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes)
if err == nil && len(globalBytes) > 0 {
confBytes = globalBytes
}
if err != nil && err != sql.ErrNoRows {
return err
}
return nil

var wkspBytes []byte
err = tx.NewSelect().Table("task_config_policies").
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id = ?", wkspID).
Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes)
if err == nil && len(globalBytes) == 0 {
confBytes = wkspBytes
}
return err
})
if err == sql.ErrNoRows {
if err == sql.ErrNoRows || len(confBytes) == 0 {
return nil, nil
}
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo
func TestGetEnforcedConfig(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
pgDB, _ := db.MustResolveNewPostgresDatabase(t)
// defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

user := db.RequireMockUser(t, pgDB)
Expand Down Expand Up @@ -710,10 +710,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

checkpointStorage, err := GetEnforcedConfig[expconf.CheckpointStorageConfig](ctx, &w.ID,
checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID,
"invariant_config", "'checkpoint_storage'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)

// global config enforced?
require.Equal(t, expconf.CheckpointStorageConfigV0{
RawSharedFSConfig: &expconf.SharedFSConfigV0{
RawHostPath: ptrs.Ptr("global_host_path"),
Expand Down Expand Up @@ -749,10 +751,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "invariant_config",
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

// workspace config enforced?
require.Equal(t, 15, *maxSlots)
})

Expand Down Expand Up @@ -788,10 +792,12 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

// global constraint enforced?
require.Equal(t, 25, *maxSlots)
})

Expand Down Expand Up @@ -823,10 +829,54 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetEnforcedConfig[int](ctx, &w.ID, "constraints",
priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)
require.Equal(t, 40, *maxSlots)
require.NotNil(t, priority)

// global constraint enforced?
require.Equal(t, 40, *priority)
})

t.Run("priority constraints wksp", func(t *testing.T) {
// delete global config policies
err = DeleteConfigPolicies(ctx, nil, model.ExperimentType)
require.NoError(t, err)

err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkspaceID: &w.ID,
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Constraints: &wkspConstraints,
})
require.NoError(t, err)

priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, priority)

// workspace constraint enforced?
require.Equal(t, 50, *priority)
})

t.Run("field not set in config", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'max_restarts'", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("nonexistent constraints field", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'max_restarts'", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("invalid policy type", func(t *testing.T) {
_, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy",
"'debug'", model.ExperimentType)
require.ErrorContains(t, err, invalidPolicyTypeErr)
})
}
41 changes: 41 additions & 0 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package configpolicy

import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
Expand All @@ -27,6 +28,9 @@ const (
InvalidNTSCConfigPolicyErr = "invalid ntsc config policy"
// NotSupportedConfigPolicyErr is the error reported when admins attempt to set NTSC invariant config.
NotSupportedConfigPolicyErr = "not supported"
// SlotsReqTooHighErr is the error reported when the requested slots violates the max slots
// constraint.
SlotsReqTooHighErr = "requested slots is violates max slots constraint"
)

// ConfigPolicyWarning logs a warning for the configuration policy component.
Expand Down Expand Up @@ -298,3 +302,40 @@ func configPolicyOverlap(config1, config2 interface{}) {
}
}
}

// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the
// enforced max slots for the workspace if that's set as an invariant config, and returns the
// requested max slots otherwise. Returns an error when max slots is not set as an invariant config
// and the requested max slots violates the constriant.
func CanSetMaxSlots(slotsReq *int, wkspID int) (bool, *int, error) {
if slotsReq == nil {
return true, slotsReq, nil
}
enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return false, nil, err
}

if enforcedMaxSlots != nil {
return true, enforcedMaxSlots, nil
}

maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return false, nil, err
}

var canSetReqSlots bool
if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit {
canSetReqSlots = true
}
if !canSetReqSlots {
return false, nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit)
}

return true, slotsReq, nil
}
Loading

0 comments on commit 70b0d5e

Please sign in to comment.