Skip to content

Commit

Permalink
Updates to Batch authentication and networking (#1517)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tamuri authored Nov 19, 2024
1 parent 28643ea commit b96bafa
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions src/tlo/cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,9 +13,9 @@
import dateutil.parser
import pandas as pd
from azure import batch
from azure.batch import batch_auth
from azure.batch import models as batch_models
from azure.batch.models import BatchErrorException
from azure.common.credentials import ServicePrincipalCredentials
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
Expand Down Expand Up @@ -132,9 +132,7 @@ def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, i
azure_directory = f"{config['DEFAULT']['USERNAME']}/{job_id}"

batch_client = get_batch_client(
config["BATCH"]["NAME"],
config["BATCH"]["KEY"],
config["BATCH"]["URL"]
config["BATCH"]["CLIENT_ID"], config["BATCH"]["SECRET"], config["AZURE"]["TENANT_ID"], config["BATCH"]["URL"]
)

create_file_share(
Expand Down Expand Up @@ -247,8 +245,16 @@ def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, i

try:
# Create the job that will run the tasks.
create_job(batch_client, vm_size, pool_node_count, job_id,
container_conf, [mount_configuration], keep_pool_alive)
create_job(
batch_client,
vm_size,
pool_node_count,
job_id,
container_conf,
[mount_configuration],
keep_pool_alive,
config["BATCH"]["SUBNET_ID"],
)

# Add the tasks to the job.
add_tasks(batch_client, user_identity, job_id, image_name,
Expand Down Expand Up @@ -297,9 +303,7 @@ def batch_terminate(ctx, job_id):
return

batch_client = get_batch_client(
config["BATCH"]["NAME"],
config["BATCH"]["KEY"],
config["BATCH"]["URL"]
config["BATCH"]["CLIENT_ID"], config["BATCH"]["SECRET"], config["AZURE"]["TENANT_ID"], config["BATCH"]["URL"]
)

# check the job is running
Expand Down Expand Up @@ -331,10 +335,9 @@ def batch_job(ctx, job_id, raw, show_tasks):
print(">Querying batch system\r", end="")
config = load_config(ctx.obj['config_file'])
batch_client = get_batch_client(
config["BATCH"]["NAME"],
config["BATCH"]["KEY"],
config["BATCH"]["URL"]
config["BATCH"]["CLIENT_ID"], config["BATCH"]["SECRET"], config["AZURE"]["TENANT_ID"], config["BATCH"]["URL"]
)

tasks = None

try:
Expand Down Expand Up @@ -402,9 +405,7 @@ def batch_list(ctx, status, n, find, username):
username = config["DEFAULT"]["USERNAME"]

batch_client = get_batch_client(
config["BATCH"]["NAME"],
config["BATCH"]["KEY"],
config["BATCH"]["URL"]
config["BATCH"]["CLIENT_ID"], config["BATCH"]["SECRET"], config["AZURE"]["TENANT_ID"], config["BATCH"]["URL"]
)

# create client to connect to file share
Expand Down Expand Up @@ -581,9 +582,12 @@ def load_server_config(kv_uri, tenant_id) -> Dict[str, Dict]:
return {"STORAGE": storage_config, "BATCH": batch_config, "REGISTRY": registry_config}


def get_batch_client(name, key, url):
def get_batch_client(client_id, secret, tenant_id, url):
"""Create a Batch service client"""
credentials = batch_auth.SharedKeyCredentials(name, key)
resource = "https://batch.core.windows.net/"

credentials = ServicePrincipalCredentials(client_id=client_id, secret=secret, tenant=tenant_id, resource=resource)

batch_client = batch.BatchServiceClient(credentials, batch_url=url)
return batch_client

Expand Down Expand Up @@ -697,10 +701,19 @@ def upload_local_file(connection_string, local_file_path, share_name, dest_file_
print("ResourceNotFoundError:", ex.message)


def create_job(batch_service_client, vm_size, pool_node_count, job_id,
container_conf, mount_configuration, keep_pool_alive):
def create_job(
batch_service_client,
vm_size,
pool_node_count,
job_id,
container_conf,
mount_configuration,
keep_pool_alive,
subnet_id,
):
"""Creates a job with the specified ID, associated with the specified pool.
:param subnet_id:
:param batch_service_client: A Batch service client.
:type batch_service_client: `azure.batch.BatchServiceClient`
:param str vm_size: Type of virtual machine to use as pool.
Expand Down Expand Up @@ -740,13 +753,20 @@ def create_job(batch_service_client, vm_size, pool_node_count, job_id,
$NodeDeallocationOption = taskcompletion;
"""

network_configuration = batch_models.NetworkConfiguration(
subnet_id=subnet_id,
public_ip_address_configuration=batch_models.PublicIPAddressConfiguration(provision="noPublicIPAddresses"),
)

pool = batch_models.PoolSpecification(
virtual_machine_configuration=virtual_machine_configuration,
vm_size=vm_size,
mount_configuration=mount_configuration,
task_slots_per_node=1,
enable_auto_scale=True,
auto_scale_formula=auto_scale_formula,
network_configuration=network_configuration,
target_node_communication_mode="simplified",
)

auto_pool_specification = batch_models.AutoPoolSpecification(
Expand Down

0 comments on commit b96bafa

Please sign in to comment.