Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept JSON policy object in cluster policy commands #557

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions databricks_cli/cluster_policies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from json import dumps as json_dumps

from databricks_cli.sdk import PolicyService


class ClusterPolicyApi(object):
def __init__(self, api_client):
self.client = PolicyService(api_client)

@staticmethod
def format_policy_for_api(policy):
if isinstance(policy.get("definition"), dict):
policy["definition"] = json_dumps(policy["definition"])
return policy

def create_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/create', data=json)
return self.client.client.perform_query(
"POST",
"/policies/clusters/create",
data=ClusterPolicyApi.format_policy_for_api(json),
)

def edit_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json)
return self.client.client.perform_query(
"POST",
"/policies/clusters/edit",
data=ClusterPolicyApi.format_policy_for_api(json),
)

def delete_cluster_policy(self, policy_id):
return self.client.delete_policy(policy_id)
Expand Down
62 changes: 62 additions & 0 deletions tests/cluster_policies/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Databricks CLI
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"), except
# that the use of services to which certain application programming
# interfaces (each, an "API") connect requires that the user first obtain
# a license for the use of the APIs from Databricks, Inc. ("Databricks"),
# by creating an account at www.databricks.com and agreeing to either (a)
# the Community Edition Terms of Service, (b) the Databricks Terms of
# Service, or (c) another written agreement between Licensee and Databricks
# for the use of the APIs.
#
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import pytest

from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.mark.parametrize(
"policy, expected",
[
({"definition": "foo"}, {"definition": "foo"}),
({"definition": {"foo": "bar"}}, {"definition": '{"foo": "bar"}'}),
],
)
def test_format_policy_for_api(policy, expected):
result = ClusterPolicyApi.format_policy_for_api(policy)
assert result == expected


@pytest.mark.parametrize(
"fct_name, method, action",
[
("create_cluster_policy", "POST", "create"),
("edit_cluster_policy", "POST", "edit"),
],
)
@mock.patch(
"databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api"
)
def test_create_and_edit_cluster_policy(
mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api
):
mock_policy = mock.Mock()
getattr(fixture_cluster_policies_api, fct_name)(mock_policy)
mock_format_policy_for_api.assert_called_once_with(mock_policy)
fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with(
method,
"/policies/clusters/{}".format(action),
data=mock_format_policy_for_api.return_value,
)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
# limitations under the License.
import shutil
import tempfile

import mock
import pytest

import databricks_cli.configure.provider as provider
from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.fixture(autouse=True)
Expand All @@ -33,3 +36,12 @@ def mock_conf_dir():
provider._home = path
yield
shutil.rmtree(path)


@pytest.fixture()
def fixture_cluster_policies_api():
with mock.patch(
"databricks_cli.cluster_policies.api.PolicyService"
) as service_mock:
service_mock.return_value = mock.MagicMock()
yield ClusterPolicyApi(None)