Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Add linting to deployment tools (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmc-msft authored Nov 20, 2020
1 parent 9e2a61f commit 3ddb756
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 76 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ jobs:
with:
name: build-artifacts
path: artifacts
- uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Lint
shell: bash
run: |
set -ex
cd src/deployment
pip install mypy isort black
mypy .
isort --profile black . --check
black . --check
- name: Package Onefuzz
run: |
set -ex
Expand Down
6 changes: 3 additions & 3 deletions src/deployment/data_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
# Licensed under the MIT License.

import argparse
from uuid import UUID
import json
from typing import Callable, Dict, List
from uuid import UUID

from azure.common.client_factory import get_client_from_cli_profile
from azure.cosmosdb.table.tablebatch import TableBatch
from azure.cosmosdb.table.tableservice import TableService
from azure.mgmt.storage import StorageManagementClient
from azure.common.client_factory import get_client_from_cli_profile


def migrate_task_os(table_service: TableService) -> None:
Expand Down Expand Up @@ -84,7 +84,7 @@ def migrate(table_service: TableService, migration_names: List[str]) -> None:
print("migration '%s' applied" % name)


def main():
def main() -> None:
formatter = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(formatter_class=formatter)
parser.add_argument("resource_group")
Expand Down
129 changes: 68 additions & 61 deletions src/deployment/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import uuid
import zipfile
from datetime import datetime, timedelta
from typing import Optional
from typing import Dict, List, Optional, Tuple, Union, cast
from uuid import UUID

from azure.common.client_factory import get_client_from_cli_profile
from azure.common.credentials import get_cli_profile
Expand Down Expand Up @@ -62,11 +63,11 @@

from data_migration import migrate
from registration import (
OnefuzzAppRole,
add_application_password,
assign_scaleset_role,
authorize_application,
get_application,
OnefuzzAppRole,
register_application,
update_pool_registration,
)
Expand Down Expand Up @@ -94,27 +95,28 @@
logger = logging.getLogger("deploy")


def gen_guid():
def gen_guid() -> str:
return str(uuid.uuid4())


class Client:
def __init__(
self,
resource_group,
location,
application_name,
owner,
client_id,
client_secret,
app_zip,
tools,
instance_specific,
third_party,
arm_template,
workbook_data,
create_registration,
migrations,
*,
resource_group: str,
location: str,
application_name: str,
owner: str,
client_id: Optional[str],
client_secret: Optional[str],
app_zip: str,
tools: str,
instance_specific: str,
third_party: str,
arm_template: str,
workbook_data: str,
create_registration: bool,
migrations: List[str],
export_appinsights: bool,
log_service_principal: bool,
upgrade: bool,
Expand All @@ -130,11 +132,11 @@ def __init__(
self.third_party = third_party
self.create_registration = create_registration
self.upgrade = upgrade
self.results = {
self.results: Dict = {
"client_id": client_id,
"client_secret": client_secret,
}
self.cli_config = {
self.cli_config: Dict[str, Union[str, UUID]] = {
"client_id": ONEFUZZ_CLI_APP,
"authority": ONEFUZZ_CLI_AUTHORITY,
}
Expand All @@ -161,22 +163,22 @@ def __init__(
with open(workbook_data) as f:
self.workbook_data = json.load(f)

def get_subscription_id(self):
def get_subscription_id(self) -> str:
profile = get_cli_profile()
return profile.get_subscription_id()
return cast(str, profile.get_subscription_id())

def get_location_display_name(self):
def get_location_display_name(self) -> str:
location_client = get_client_from_cli_profile(SubscriptionClient)
locations = location_client.subscriptions.list_locations(
self.get_subscription_id()
)
for location in locations:
if location.name == self.location:
return location.display_name
return cast(str, location.display_name)

raise Exception("unknown location: %s", self.location)

def check_region(self):
def check_region(self) -> None:
# At the moment, this only checks are the specified providers available
# in the selected region

Expand Down Expand Up @@ -223,7 +225,7 @@ def check_region(self):
print("\n".join(["* " + x for x in unsupported]))
sys.exit(1)

def create_password(self, object_id):
def create_password(self, object_id: UUID) -> Tuple[str, str]:
# Work-around the race condition where the app is created but passwords cannot
# be created yet.
count = 0
Expand All @@ -238,7 +240,7 @@ def create_password(self, object_id):
if count > timeout_seconds / wait:
raise Exception("creating password failed, trying again")

def setup_rbac(self):
def setup_rbac(self) -> None:
"""
Setup the client application for the OneFuzz instance.
Expand Down Expand Up @@ -281,6 +283,8 @@ def setup_rbac(self):
),
]

app: Optional[Application] = None

if not existing:
logger.info("creating Application registration")
url = "https://%s.azurewebsites.net" % self.application_name
Expand Down Expand Up @@ -311,7 +315,7 @@ def setup_rbac(self):
)
client.service_principals.create(service_principal_params)
else:
app: Application = existing[0]
app = existing[0]
existing_role_values = [app_role.value for app_role in app.app_roles]
has_missing_roles = any(
[role.value not in existing_role_values for role in app_roles]
Expand Down Expand Up @@ -365,7 +369,7 @@ def setup_rbac(self):
else:
logger.debug("client_id: %s client_secret: %s", app.app_id, password)

def deploy_template(self):
def deploy_template(self) -> None:
logger.info("deploying arm template: %s", self.arm_template)

with open(self.arm_template, "r") as template_handle:
Expand Down Expand Up @@ -403,7 +407,7 @@ def deploy_template(self):
sys.exit(1)
self.results["deploy"] = result.properties.outputs

def assign_scaleset_identity_role(self):
def assign_scaleset_identity_role(self) -> None:
if self.upgrade:
logger.info("Upgrading: skipping assignment of the managed identity role")
return
Expand All @@ -413,14 +417,14 @@ def assign_scaleset_identity_role(self):
self.results["deploy"]["scaleset-identity"]["value"],
)

def apply_migrations(self):
def apply_migrations(self) -> None:
self.results["deploy"]["func-storage"]["value"]
name = self.results["deploy"]["func-name"]["value"]
key = self.results["deploy"]["func-key"]["value"]
table_service = TableService(account_name=name, account_key=key)
migrate(table_service, self.migrations)

def create_queues(self):
def create_queues(self) -> None:
logger.info("creating eventgrid destination queue")

name = self.results["deploy"]["func-name"]["value"]
Expand All @@ -443,7 +447,7 @@ def create_queues(self):
except ResourceExistsError:
pass

def create_eventgrid(self):
def create_eventgrid(self) -> None:
logger.info("creating eventgrid subscription")
src_resource_id = self.results["deploy"]["fuzz-storage"]["value"]
dst_resource_id = self.results["deploy"]["func-storage"]["value"]
Expand Down Expand Up @@ -474,7 +478,7 @@ def create_eventgrid(self):
% json.dumps(result.as_dict(), indent=4, sort_keys=True),
)

def add_instance_id(self):
def add_instance_id(self) -> None:
logger.info("setting instance_id log export")

container_name = "base-config"
Expand All @@ -497,7 +501,7 @@ def add_instance_id(self):

logger.info("instance_id: %s", instance_id)

def add_log_export(self):
def add_log_export(self) -> None:
if not self.export_appinsights:
logger.info("not exporting appinsights")
return
Expand Down Expand Up @@ -561,7 +565,7 @@ def add_log_export(self):
self.resource_group, self.application_name, req
)

def upload_tools(self):
def upload_tools(self) -> None:
logger.info("uploading tools from %s", self.tools)
account_name = self.results["deploy"]["func-name"]["value"]
key = self.results["deploy"]["func-key"]["value"]
Expand All @@ -587,7 +591,7 @@ def upload_tools(self):
[self.azcopy, "sync", self.tools, url, "--delete-destination", "true"]
)

def upload_instance_setup(self):
def upload_instance_setup(self) -> None:
logger.info("uploading instance-specific-setup from %s", self.instance_specific)
account_name = self.results["deploy"]["func-name"]["value"]
key = self.results["deploy"]["func-key"]["value"]
Expand Down Expand Up @@ -622,7 +626,7 @@ def upload_instance_setup(self):
]
)

def upload_third_party(self):
def upload_third_party(self) -> None:
logger.info("uploading third-party tools from %s", self.third_party)
account_name = self.results["deploy"]["fuzz-name"]["value"]
key = self.results["deploy"]["fuzz-key"]["value"]
Expand Down Expand Up @@ -654,18 +658,21 @@ def upload_third_party(self):
[self.azcopy, "sync", path, url, "--delete-destination", "true"]
)

def deploy_app(self):
def deploy_app(self) -> None:
logger.info("deploying function app %s", self.app_zip)
with tempfile.TemporaryDirectory() as tmpdirname:
with zipfile.ZipFile(self.app_zip, "r") as zip_ref:
func = shutil.which("func")
assert func is not None

zip_ref.extractall(tmpdirname)
error: Optional[subprocess.CalledProcessError] = None
max_tries = 5
for i in range(max_tries):
try:
subprocess.check_output(
[
shutil.which("func"),
func,
"azure",
"functionapp",
"publish",
Expand All @@ -688,12 +695,12 @@ def deploy_app(self):
if error is not None:
raise error

def update_registration(self):
def update_registration(self) -> None:
if not self.create_registration:
return
update_pool_registration(self.application_name)

def done(self):
def done(self) -> None:
logger.info(TELEMETRY_NOTICE)
client_secret_arg = (
("--client_secret %s" % self.cli_config["client_secret"])
Expand All @@ -710,19 +717,19 @@ def done(self):
)


def arg_dir(arg):
def arg_dir(arg: str) -> str:
if not os.path.isdir(arg):
raise argparse.ArgumentTypeError("not a directory: %s" % arg)
return arg


def arg_file(arg):
def arg_file(arg: str) -> str:
if not os.path.isfile(arg):
raise argparse.ArgumentTypeError("not a file: %s" % arg)
return arg


def main():
def main() -> None:
states = [
("check_region", Client.check_region),
("rbac", Client.setup_rbac),
Expand Down Expand Up @@ -826,23 +833,23 @@ def main():
sys.exit(1)

client = Client(
args.resource_group,
args.location,
args.application_name,
args.owner,
args.client_id,
args.client_secret,
args.app_zip,
args.tools,
args.instance_specific,
args.third_party,
args.arm_template,
args.workbook_data,
args.create_pool_registration,
args.apply_migrations,
args.export_appinsights,
args.log_service_principal,
args.upgrade,
resource_group=args.resource_group,
location=args.location,
application_name=args.application_name,
owner=args.owner,
client_id=args.client_id,
client_secret=args.client_secret,
app_zip=args.app_zip,
tools=args.tools,
instance_specific=args.instance_specific,
third_party=args.third_party,
arm_template=args.arm_template,
workbook_data=args.workbook_data,
create_registration=args.create_pool_registration,
migrations=args.apply_migrations,
export_appinsights=args.export_appinsights,
log_service_principal=args.log_service_principal,
upgrade=args.upgrade,
)
if args.verbose:
level = logging.DEBUG
Expand Down
Loading

0 comments on commit 3ddb756

Please sign in to comment.