Skip to content

Commit

Permalink
Merge branch 'main' into kanterov/python-locations
Browse files Browse the repository at this point in the history
  • Loading branch information
pietern authored Sep 25, 2024
2 parents 33f4b73 + 3d9decd commit def4744
Show file tree
Hide file tree
Showing 38 changed files with 1,129 additions and 70 deletions.
7 changes: 5 additions & 2 deletions bundle/config/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ type Bundle struct {
// Annotated readonly as this should be set at the target level.
Mode Mode `json:"mode,omitempty" bundle:"readonly"`

// Overrides the compute used for jobs and other supported assets.
ComputeID string `json:"compute_id,omitempty"`
// DEPRECATED: Overrides the compute used for jobs and other supported assets.
ComputeId string `json:"compute_id,omitempty"`

// Overrides the cluster used for jobs and other supported assets.
ClusterId string `json:"cluster_id,omitempty"`

// Deployment section specifies deployment related configuration for bundle
Deployment Deployment `json:"deployment,omitempty"`
Expand Down
15 changes: 15 additions & 0 deletions bundle/config/mutator/apply_presets.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos
// the Databricks UI and via the SQL API.
}

// Clusters: Prefix, Tags
for _, c := range r.Clusters {
c.ClusterName = prefix + c.ClusterName
if c.CustomTags == nil {
c.CustomTags = make(map[string]string)
}
for _, tag := range tags {
normalisedKey := b.Tagging.NormalizeKey(tag.Key)
normalisedValue := b.Tagging.NormalizeValue(tag.Value)
if _, ok := c.CustomTags[normalisedKey]; !ok {
c.CustomTags[normalisedKey] = normalisedValue
}
}
}

return nil
}

Expand Down
87 changes: 87 additions & 0 deletions bundle/config/mutator/compute_id_compat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package mutator

import (
"context"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
)

type computeIdToClusterId struct{}

func ComputeIdToClusterId() bundle.Mutator {
return &computeIdToClusterId{}
}

func (m *computeIdToClusterId) Name() string {
return "ComputeIdToClusterId"
}

func (m *computeIdToClusterId) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
var diags diag.Diagnostics

// The "compute_id" key is set; rewrite it to "cluster_id".
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
v, d := rewriteComputeIdToClusterId(v, dyn.NewPath(dyn.Key("bundle")))
diags = diags.Extend(d)

// Check if the "compute_id" key is set in any target overrides.
return dyn.MapByPattern(v, dyn.NewPattern(dyn.Key("targets"), dyn.AnyKey()), func(p dyn.Path, v dyn.Value) (dyn.Value, error) {
v, d := rewriteComputeIdToClusterId(v, dyn.Path{})
diags = diags.Extend(d)
return v, nil
})
})

diags = diags.Extend(diag.FromErr(err))
return diags
}

func rewriteComputeIdToClusterId(v dyn.Value, p dyn.Path) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
computeIdPath := p.Append(dyn.Key("compute_id"))
computeId, err := dyn.GetByPath(v, computeIdPath)

// If the "compute_id" key is not set, we don't need to do anything.
if err != nil {
return v, nil
}

if computeId.Kind() == dyn.KindInvalid {
return v, nil
}

diags = diags.Append(diag.Diagnostic{
Severity: diag.Warning,
Summary: "compute_id is deprecated, please use cluster_id instead",
Locations: computeId.Locations(),
Paths: []dyn.Path{computeIdPath},
})

clusterIdPath := p.Append(dyn.Key("cluster_id"))
nv, err := dyn.SetByPath(v, clusterIdPath, computeId)
if err != nil {
return dyn.InvalidValue, diag.FromErr(err)
}
// Drop the "compute_id" key.
vout, err := dyn.Walk(nv, func(p dyn.Path, v dyn.Value) (dyn.Value, error) {
switch len(p) {
case 0:
return v, nil
case 1:
if p[0] == dyn.Key("compute_id") {
return v, dyn.ErrDrop
}
return v, nil
case 2:
if p[1] == dyn.Key("compute_id") {
return v, dyn.ErrDrop
}
}
return v, dyn.ErrSkip
})

diags = diags.Extend(diag.FromErr(err))
return vout, diags
}
57 changes: 57 additions & 0 deletions bundle/config/mutator/compute_id_compate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package mutator_test

import (
"context"
"testing"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/libs/diag"
"github.com/stretchr/testify/assert"
)

func TestComputeIdToClusterId(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Bundle: config.Bundle{
ComputeId: "compute-id",
},
},
}

diags := bundle.Apply(context.Background(), b, mutator.ComputeIdToClusterId())
assert.NoError(t, diags.Error())
assert.Equal(t, "compute-id", b.Config.Bundle.ClusterId)
assert.Empty(t, b.Config.Bundle.ComputeId)

assert.Len(t, diags, 1)
assert.Equal(t, "compute_id is deprecated, please use cluster_id instead", diags[0].Summary)
assert.Equal(t, diag.Warning, diags[0].Severity)
}

func TestComputeIdToClusterIdInTargetOverride(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Targets: map[string]*config.Target{
"dev": {
ComputeId: "compute-id-dev",
},
},
},
}

diags := bundle.Apply(context.Background(), b, mutator.ComputeIdToClusterId())
assert.NoError(t, diags.Error())
assert.Empty(t, b.Config.Targets["dev"].ComputeId)

diags = diags.Extend(bundle.Apply(context.Background(), b, mutator.SelectTarget("dev")))
assert.NoError(t, diags.Error())

assert.Equal(t, "compute-id-dev", b.Config.Bundle.ClusterId)
assert.Empty(t, b.Config.Bundle.ComputeId)

assert.Len(t, diags, 1)
assert.Equal(t, "compute_id is deprecated, please use cluster_id instead", diags[0].Summary)
assert.Equal(t, diag.Warning, diags[0].Severity)
}
1 change: 1 addition & 0 deletions bundle/config/mutator/mutator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func DefaultMutators() []bundle.Mutator {
VerifyCliVersion(),

EnvironmentsToTargets(),
ComputeIdToClusterId(),
InitializeVariables(),
DefineDefaultTarget(),
LoadGitDetails(),
Expand Down
8 changes: 4 additions & 4 deletions bundle/config/mutator/override_compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ func overrideJobCompute(j *resources.Job, compute string) {

func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
if b.Config.Bundle.Mode != config.Development {
if b.Config.Bundle.ComputeID != "" {
if b.Config.Bundle.ClusterId != "" {
return diag.Errorf("cannot override compute for an target that does not use 'mode: development'")
}
return nil
}
if v := env.Get(ctx, "DATABRICKS_CLUSTER_ID"); v != "" {
b.Config.Bundle.ComputeID = v
b.Config.Bundle.ClusterId = v
}

if b.Config.Bundle.ComputeID == "" {
if b.Config.Bundle.ClusterId == "" {
return nil
}

r := b.Config.Resources
for i := range r.Jobs {
overrideJobCompute(r.Jobs[i], b.Config.Bundle.ComputeID)
overrideJobCompute(r.Jobs[i], b.Config.Bundle.ClusterId)
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions bundle/config/mutator/override_compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestOverrideDevelopment(t *testing.T) {
Config: config.Root{
Bundle: config.Bundle{
Mode: config.Development,
ComputeID: "newClusterID",
ClusterId: "newClusterID",
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestOverrideProduction(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Bundle: config.Bundle{
ComputeID: "newClusterID",
ClusterId: "newClusterID",
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
Expand Down
10 changes: 10 additions & 0 deletions bundle/config/mutator/process_target_mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/databricks/cli/libs/tags"
sdkconfig "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
Expand Down Expand Up @@ -119,6 +120,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
Schemas: map[string]*resources.Schema{
"schema1": {CreateSchema: &catalog.CreateSchema{Name: "schema1"}},
},
Clusters: map[string]*resources.Cluster{
"cluster1": {ClusterSpec: &compute.ClusterSpec{ClusterName: "cluster1", SparkVersion: "13.2.x", NumWorkers: 1}},
},
},
},
// Use AWS implementation for testing.
Expand Down Expand Up @@ -177,6 +181,9 @@ func TestProcessTargetModeDevelopment(t *testing.T) {

// Schema 1
assert.Equal(t, "dev_lennart_schema1", b.Config.Resources.Schemas["schema1"].Name)

// Clusters
assert.Equal(t, "[dev lennart] cluster1", b.Config.Resources.Clusters["cluster1"].ClusterName)
}

func TestProcessTargetModeDevelopmentTagNormalizationForAws(t *testing.T) {
Expand Down Expand Up @@ -281,6 +288,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
assert.Equal(t, "servingendpoint1", b.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", b.Config.Resources.RegisteredModels["registeredmodel1"].Name)
assert.Equal(t, "qualityMonitor1", b.Config.Resources.QualityMonitors["qualityMonitor1"].TableName)
assert.Equal(t, "cluster1", b.Config.Resources.Clusters["cluster1"].ClusterName)
}

func TestProcessTargetModeProduction(t *testing.T) {
Expand Down Expand Up @@ -312,6 +320,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
b.Config.Resources.Experiments["experiment2"].Permissions = permissions
b.Config.Resources.Models["model1"].Permissions = permissions
b.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions
b.Config.Resources.Clusters["cluster1"].Permissions = permissions

diags = validateProductionMode(context.Background(), b, false)
require.NoError(t, diags.Error())
Expand All @@ -322,6 +331,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
assert.Equal(t, "servingendpoint1", b.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", b.Config.Resources.RegisteredModels["registeredmodel1"].Name)
assert.Equal(t, "qualityMonitor1", b.Config.Resources.QualityMonitors["qualityMonitor1"].TableName)
assert.Equal(t, "cluster1", b.Config.Resources.Clusters["cluster1"].ClusterName)
}

func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions bundle/config/mutator/run_as_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func allResourceTypes(t *testing.T) []string {
// the dyn library gives us the correct list of all resources supported. Please
// also update this check when adding a new resource
require.Equal(t, []string{
"clusters",
"experiments",
"jobs",
"model_serving_endpoints",
Expand Down Expand Up @@ -133,6 +134,7 @@ func TestRunAsErrorForUnsupportedResources(t *testing.T) {
// some point in the future. These resources are (implicitly) on the deny list, since
// they are not on the allow list below.
allowList := []string{
"clusters",
"jobs",
"models",
"registered_models",
Expand Down
1 change: 1 addition & 0 deletions bundle/config/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Resources struct {
RegisteredModels map[string]*resources.RegisteredModel `json:"registered_models,omitempty"`
QualityMonitors map[string]*resources.QualityMonitor `json:"quality_monitors,omitempty"`
Schemas map[string]*resources.Schema `json:"schemas,omitempty"`
Clusters map[string]*resources.Cluster `json:"clusters,omitempty"`
}

type ConfigResource interface {
Expand Down
39 changes: 39 additions & 0 deletions bundle/config/resources/clusters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package resources

import (
"context"

"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/marshal"
"github.com/databricks/databricks-sdk-go/service/compute"
)

type Cluster struct {
ID string `json:"id,omitempty" bundle:"readonly"`
Permissions []Permission `json:"permissions,omitempty"`
ModifiedStatus ModifiedStatus `json:"modified_status,omitempty" bundle:"internal"`

*compute.ClusterSpec
}

func (s *Cluster) UnmarshalJSON(b []byte) error {
return marshal.Unmarshal(b, s)
}

func (s Cluster) MarshalJSON() ([]byte, error) {
return marshal.Marshal(s)
}

func (s *Cluster) Exists(ctx context.Context, w *databricks.WorkspaceClient, id string) (bool, error) {
_, err := w.Clusters.GetByClusterId(ctx, id)
if err != nil {
log.Debugf(ctx, "cluster %s does not exist", id)
return false, err
}
return true, nil
}

func (s *Cluster) TerraformResourceName() string {
return "databricks_cluster"
}
6 changes: 3 additions & 3 deletions bundle/config/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,9 @@ func (r *Root) MergeTargetOverrides(name string) error {
}
}

// Merge `compute_id`. This field must be overwritten if set, not merged.
if v := target.Get("compute_id"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("compute_id")), v)
// Merge `cluster_id`. This field must be overwritten if set, not merged.
if v := target.Get("cluster_id"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("cluster_id")), v)
if err != nil {
return err
}
Expand Down
7 changes: 5 additions & 2 deletions bundle/config/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ type Target struct {
// name prefix of deployed resources.
Presets Presets `json:"presets,omitempty"`

// Overrides the compute used for jobs and other supported assets.
ComputeID string `json:"compute_id,omitempty"`
// DEPRECATED: Overrides the compute used for jobs and other supported assets.
ComputeId string `json:"compute_id,omitempty"`

// Overrides the cluster used for jobs and other supported assets.
ClusterId string `json:"cluster_id,omitempty"`

Bundle *Bundle `json:"bundle,omitempty"`

Expand Down
Loading

0 comments on commit def4744

Please sign in to comment.