diff --git a/api/src/main/java/com/epam/pipeline/manager/cloud/CloudInstanceService.java b/api/src/main/java/com/epam/pipeline/manager/cloud/CloudInstanceService.java index 87c2aeaa2c..28e0eca984 100644 --- a/api/src/main/java/com/epam/pipeline/manager/cloud/CloudInstanceService.java +++ b/api/src/main/java/com/epam/pipeline/manager/cloud/CloudInstanceService.java @@ -24,6 +24,7 @@ import com.epam.pipeline.entity.pipeline.DiskAttachRequest; import com.epam.pipeline.entity.pipeline.RunInstance; import com.epam.pipeline.entity.region.AbstractCloudRegion; +import com.epam.pipeline.manager.cluster.KubernetesConstants; import com.epam.pipeline.manager.cluster.autoscale.AutoscalerServiceImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +33,7 @@ import java.time.Duration; import java.time.LocalDateTime; import java.time.format.DateTimeParseException; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -43,6 +45,10 @@ public interface CloudInstanceService int TIME_DELIMITER = 60; int TIME_TO_SHUT_DOWN_NODE = 1; + default Map getPoolLabels(final NodePool pool) { + return Collections.singletonMap(KubernetesConstants.NODE_POOL_ID_LABEL, String.valueOf(pool.getId())); + } + /** * Creates new instance using specified cloud and adds it to cluster * @param runId diff --git a/api/src/main/java/com/epam/pipeline/manager/cloud/aws/AWSInstanceService.java b/api/src/main/java/com/epam/pipeline/manager/cloud/aws/AWSInstanceService.java index 3c72c28d1e..87db3adeeb 100644 --- a/api/src/main/java/com/epam/pipeline/manager/cloud/aws/AWSInstanceService.java +++ b/api/src/main/java/com/epam/pipeline/manager/cloud/aws/AWSInstanceService.java @@ -36,7 +36,6 @@ import com.epam.pipeline.manager.cloud.commands.ClusterCommandService; import com.epam.pipeline.manager.cloud.commands.NodeUpCommand; import com.epam.pipeline.manager.cluster.InstanceOfferManager; -import com.epam.pipeline.manager.cluster.KubernetesConstants; import com.epam.pipeline.manager.execution.SystemParams; import com.epam.pipeline.manager.preference.PreferenceManager; import com.epam.pipeline.manager.preference.SystemPreferences; @@ -109,9 +108,7 @@ public RunInstance scaleUpPoolNode(final AwsRegion region, final String nodeIdLabel, final NodePool node) { final RunInstance instance = node.toRunInstance(); - final Map labels = Collections.singletonMap( - KubernetesConstants.NODE_POOL_ID_LABEL, String.valueOf(node.getId())); - final String command = buildNodeUpCommand(region, nodeIdLabel, instance, labels); + final String command = buildNodeUpCommand(region, nodeIdLabel, instance, getPoolLabels(node)); return instanceService.runNodeUpScript(cmdExecutor, null, instance, command, buildScriptEnvVars()); } diff --git a/api/src/main/java/com/epam/pipeline/manager/cloud/azure/AzureInstanceService.java b/api/src/main/java/com/epam/pipeline/manager/cloud/azure/AzureInstanceService.java index de6d531ddc..c82f4cf870 100644 --- a/api/src/main/java/com/epam/pipeline/manager/cloud/azure/AzureInstanceService.java +++ b/api/src/main/java/com/epam/pipeline/manager/cloud/azure/AzureInstanceService.java @@ -47,6 +47,7 @@ import org.springframework.stereotype.Service; import java.time.LocalDateTime; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -104,21 +105,22 @@ public AzureInstanceService(final CommonCloudInstanceService instanceService, public RunInstance scaleUpNode(final AzureRegion region, final Long runId, final RunInstance instance) { - final String command = buildNodeUpCommand(region, runId, instance); - final Map envVars = buildScriptAzureEnvVars(region); - return instanceService.runNodeUpScript(cmdExecutor, runId, instance, command, envVars); + final String command = buildNodeUpCommand(region, String.valueOf(runId), instance, Collections.emptyMap()); + return instanceService.runNodeUpScript(cmdExecutor, runId, instance, command, buildScriptAzureEnvVars(region)); } @Override public RunInstance scaleUpPoolNode(final AzureRegion region, final String nodeId, - final NodePool node) { - throw new UnsupportedOperationException(); + final NodePool nodePool) { + final RunInstance instance = nodePool.toRunInstance(); + final String command = buildNodeUpCommand(region, nodeId, instance, getPoolLabels(nodePool)); + return instanceService.runNodeUpScript(cmdExecutor, null, instance, command, buildScriptAzureEnvVars(region)); } @Override public void scaleDownNode(final AzureRegion region, final Long runId) { - final String command = buildNodeDownCommand(runId); + final String command = buildNodeDownCommand(String.valueOf(runId)); final Map envVars = buildScriptAzureEnvVars(region); CompletableFuture.runAsync(() -> instanceService.runNodeDownScript(cmdExecutor, command, envVars), executorService.getExecutorService()); @@ -126,7 +128,10 @@ public void scaleDownNode(final AzureRegion region, final Long runId) { @Override public void scaleDownPoolNode(final AzureRegion region, final String nodeLabel) { - throw new UnsupportedOperationException(); + final String command = buildNodeDownCommand(nodeLabel); + final Map envVars = buildScriptAzureEnvVars(region); + CompletableFuture.runAsync(() -> instanceService.runNodeDownScript(cmdExecutor, command, envVars), + executorService.getExecutorService()); } @Override @@ -138,10 +143,11 @@ public boolean reassignNode(final AzureRegion region, final Long oldId, final Lo } @Override - public boolean reassignPoolNode(final AzureRegion region, - final String nodeLabel, - final Long newId) { - throw new UnsupportedOperationException(); + public boolean reassignPoolNode(final AzureRegion region, final String nodeLabel, final Long newId) { + final String command = commandService. + buildNodeReassignCommand(nodeReassignScript, nodeLabel, String.valueOf(newId), getProvider().name()); + return instanceService.runNodeReassignScript(cmdExecutor, command, nodeLabel, + String.valueOf(newId), buildScriptAzureEnvVars(region)); } @Override @@ -266,19 +272,22 @@ private Map buildScriptAzureEnvVars(final AzureRegion region) { return envVars; } - private String buildNodeUpCommand(final AzureRegion region, final Long runId, final RunInstance instance) { + private String buildNodeUpCommand(final AzureRegion region, final String nodeLabel, final RunInstance instance, + final Map labels) { final NodeUpCommand.NodeUpCommandBuilder commandBuilder = NodeUpCommand.builder() .executable(AbstractClusterCommand.EXECUTABLE) .script(nodeUpScript) - .runId(String.valueOf(runId)) + .runId(nodeLabel) .sshKey(region.getSshPublicKeyPath()) .instanceImage(instance.getNodeImage()) .instanceType(instance.getNodeType()) .instanceDisk(String.valueOf(instance.getEffectiveNodeDisk())) .kubeIP(kubeMasterIP) .kubeToken(kubeToken) - .region(region.getRegionCode()); + .region(region.getRegionCode()) + .prePulledImages(instance.getPrePulledDockerImages()) + .additionalLabels(labels); final Boolean clusterSpotStrategy = instance.getSpot() == null ? preferenceManager.getPreference(SystemPreferences.CLUSTER_SPOT) @@ -290,11 +299,11 @@ private String buildNodeUpCommand(final AzureRegion region, final Long runId, fi return commandBuilder.build().getCommand(); } - private String buildNodeDownCommand(final Long runId) { + private String buildNodeDownCommand(final String nodeLabel) { return RunIdArgCommand.builder() .executable(AbstractClusterCommand.EXECUTABLE) .script(nodeDownScript) - .runId(String.valueOf(runId)) + .runId(nodeLabel) .build() .getCommand(); } diff --git a/api/src/main/java/com/epam/pipeline/manager/cloud/commands/ClusterCommandService.java b/api/src/main/java/com/epam/pipeline/manager/cloud/commands/ClusterCommandService.java index ec76ec7b0c..54ad7652b5 100644 --- a/api/src/main/java/com/epam/pipeline/manager/cloud/commands/ClusterCommandService.java +++ b/api/src/main/java/com/epam/pipeline/manager/cloud/commands/ClusterCommandService.java @@ -33,14 +33,6 @@ public ClusterCommandService(@Value("${kube.master.ip}") final String kubeMaster this.kubeToken = kubeToken; } - public NodeUpCommand.NodeUpCommandBuilder buildNodeUpCommand(final String nodeUpScript, - final AbstractCloudRegion region, - final Long runId, - final RunInstance instance, - final String cloud) { - return buildNodeUpCommand(nodeUpScript, region, String.valueOf(runId), instance, cloud); - } - public NodeUpCommand.NodeUpCommandBuilder buildNodeUpCommand(final String nodeUpScript, final AbstractCloudRegion region, final String nodeLabel, diff --git a/api/src/main/java/com/epam/pipeline/manager/cloud/gcp/GCPInstanceService.java b/api/src/main/java/com/epam/pipeline/manager/cloud/gcp/GCPInstanceService.java index 7008323978..3869b23bc6 100644 --- a/api/src/main/java/com/epam/pipeline/manager/cloud/gcp/GCPInstanceService.java +++ b/api/src/main/java/com/epam/pipeline/manager/cloud/gcp/GCPInstanceService.java @@ -43,6 +43,7 @@ import java.nio.file.InvalidPathException; import java.nio.file.Paths; import java.time.LocalDateTime; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -88,20 +89,15 @@ public GCPInstanceService(final ClusterCommandService commandService, @Override public RunInstance scaleUpNode(final GCPRegion region, final Long runId, final RunInstance instance) { - final String command = commandService.buildNodeUpCommand(nodeUpScript, region, runId, instance, - getProviderName()).sshKey(region.getSshPublicKeyPath()) - .isSpot(Optional.ofNullable(instance.getSpot()) - .orElse(false)) - .bidPrice(StringUtils.EMPTY) - .build() - .getCommand(); - final Map envVars = buildScriptGCPEnvVars(region); - return instanceService.runNodeUpScript(cmdExecutor, runId, instance, command, envVars); + final String command = buildNodeUpCommand(region, String.valueOf(runId), instance, Collections.emptyMap()); + return instanceService.runNodeUpScript(cmdExecutor, runId, instance, command, buildScriptGCPEnvVars(region)); } @Override public RunInstance scaleUpPoolNode(final GCPRegion region, final String nodeId, final NodePool node) { - throw new UnsupportedOperationException(); + final RunInstance instance = node.toRunInstance(); + final String command = buildNodeUpCommand(region, nodeId, instance, getPoolLabels(node)); + return instanceService.runNodeUpScript(cmdExecutor, null, instance, command, buildScriptGCPEnvVars(region)); } @Override @@ -113,7 +109,8 @@ public void scaleDownNode(final GCPRegion region, final Long runId) { @Override public void scaleDownPoolNode(final GCPRegion region, final String nodeLabel) { - throw new UnsupportedOperationException(); + final String command = commandService.buildNodeDownCommand(nodeDownScript, nodeLabel, getProviderName()); + instanceService.runNodeDownScript(cmdExecutor, command, buildScriptGCPEnvVars(region)); } @Override @@ -126,10 +123,12 @@ public boolean reassignNode(final GCPRegion region, final Long oldId, final Long @Override public boolean reassignPoolNode(final GCPRegion region, final String nodeLabel, final Long newId) { - throw new UnsupportedOperationException(); + final String command = commandService + .buildNodeReassignCommand(nodeReassignScript, nodeLabel, String.valueOf(newId), getProviderName()); + return instanceService.runNodeReassignScript(cmdExecutor, command, nodeLabel, String.valueOf(newId), + buildScriptGCPEnvVars(region)); } - @Override public void terminateNode(final GCPRegion region, final String internalIp, final String nodeName) { final String command = commandService.buildTerminateNodeCommand(nodeTerminateScript, internalIp, @@ -248,6 +247,20 @@ public CloudInstanceState getInstanceState(final GCPRegion region, final String } } + private String buildNodeUpCommand(final GCPRegion region, final String nodeLabel, final RunInstance instance, + final Map labels) { + return commandService + .buildNodeUpCommand(nodeUpScript, region, nodeLabel, instance, getProviderName()) + .sshKey(region.getSshPublicKeyPath()) + .isSpot(Optional.ofNullable(instance.getSpot()) + .orElse(false)) + .bidPrice(StringUtils.EMPTY) + .additionalLabels(labels) + .prePulledImages(instance.getPrePulledDockerImages()) + .build() + .getCommand(); + } + private String getCredentialsFilePath(GCPRegion region) { return StringUtils.isEmpty(region.getAuthFile()) ? System.getenv(GOOGLE_APPLICATION_CREDENTIALS) diff --git a/scripts/autoscaling/azure/nodeup.py b/scripts/autoscaling/azure/nodeup.py index 78a5aeff58..5f0ee8ba27 100644 --- a/scripts/autoscaling/azure/nodeup.py +++ b/scripts/autoscaling/azure/nodeup.py @@ -33,6 +33,7 @@ from azure.mgmt.network import NetworkManagementClient from msrestazure.azure_exceptions import CloudError from pipeline import Logger, TaskStatus, PipelineAPI, pack_script_contents +import jwt VM_NAME_PREFIX = "az-" UUID_LENGHT = 16 @@ -52,6 +53,21 @@ script_path = None +def is_run_id_numerical(run_id): + try: + int(run_id) + return True + except ValueError: + return False + + +def is_api_logging_enabled(): + global api_token + global api_url + global current_run_id + return is_run_id_numerical(current_run_id) and api_url and api_token + + def pipe_log_init(run_id): global api_token global api_url @@ -61,7 +77,7 @@ def pipe_log_init(run_id): api_url = os.environ["API"] api_token = os.environ["API_TOKEN"] - if not api_url or not api_token: + if not is_api_logging_enabled(): logging.basicConfig(filename='nodeup.log', level=logging.INFO, format='%(asctime)s %(message)s') @@ -71,7 +87,7 @@ def pipe_log_warn(message): global script_path global current_run_id - if api_url and api_token: + if is_api_logging_enabled(): Logger.warn('[{}] {}'.format(current_run_id, message), task_name=NODEUP_TASK, run_id=current_run_id, @@ -88,7 +104,7 @@ def pipe_log(message, status=TaskStatus.RUNNING): global script_path global current_run_id - if api_url and api_token: + if is_api_logging_enabled(): Logger.log_task_event(NODEUP_TASK, '[{}] {}'.format(current_run_id, message), run_id=current_run_id, @@ -217,10 +233,11 @@ def get_allowed_instance_image(cloud_region, instance_type, default_image): def run_instance(api_url, api_token, instance_name, instance_type, cloud_region, run_id, ins_hdd, ins_img, ssh_pub_key, user, - ins_type, is_spot, kube_ip, kubeadm_token): + ins_type, is_spot, kube_ip, kubeadm_token, pre_pull_images): ins_key = read_ssh_key(ssh_pub_key) swap_size = get_swap_size(cloud_region, ins_type, is_spot) - user_data_script = get_user_data_script(api_url, api_token, cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, swap_size) + user_data_script = get_user_data_script(api_url, api_token, cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, + swap_size, pre_pull_images) access_config = get_access_config(cloud_region) disable_external_access = False if access_config is not None: @@ -705,7 +722,7 @@ def verify_regnode(ins_id, num_rep, time_rep, api): return ret_namenode -def label_node(nodename, run_id, api, cluster_name, cluster_role, cloud_region): +def label_node(nodename, run_id, api, cluster_name, cluster_role, cloud_region, additional_labels): pipe_log('Assigning instance {} to RunID: {}'.format(nodename, run_id)) obj = { "apiVersion": "v1", @@ -719,6 +736,14 @@ def label_node(nodename, run_id, api, cluster_name, cluster_role, cloud_region): } } + if additional_labels: + for label in additional_labels: + label_parts = label.split("=") + if len(label_parts) == 1: + obj["metadata"]["labels"][label_parts[0]] = None + else: + obj["metadata"]["labels"][label_parts[0]] = label_parts[1] + if cluster_name: obj["metadata"]["labels"]["cp-cluster-name"] = cluster_name if cluster_role: @@ -926,7 +951,21 @@ def get_swap_ratio(swap_params): return None -def get_user_data_script(api_url, api_token, cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, swap_size): +def replace_docker_images(pre_pull_images, user_data_script): + global api_token + payload = jwt.decode(api_token, verify=False) + if 'sub' in payload: + subject = payload['sub'] + user_data_script = user_data_script \ + .replace("@PRE_PULL_DOCKERS@", ",".join(pre_pull_images)) \ + .replace("@API_USER@", subject) + return user_data_script + else: + raise RuntimeError("Pre-pulled docker initialization failed: unable to parse JWT token for docker auth.") + + +def get_user_data_script(api_url, api_token, cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, swap_size, + pre_pull_images): allowed_instance = get_allowed_instance_image(cloud_region, ins_type, ins_img) if allowed_instance and allowed_instance["init_script"]: init_script = open(allowed_instance["init_script"], 'r') @@ -936,6 +975,7 @@ def get_user_data_script(api_url, api_token, cloud_region, ins_type, ins_img, ku init_script.close() user_data_script = replace_proxies(cloud_region, user_data_script) user_data_script = replace_swap(swap_size, user_data_script) + user_data_script = replace_docker_images(pre_pull_images, user_data_script) user_data_script = user_data_script.replace('@DOCKER_CERTS@', certs_string) \ .replace('@WELL_KNOWN_HOSTS@', well_known_string) \ .replace('@KUBE_IP@', kube_ip) \ @@ -1000,6 +1040,8 @@ def main(): parser.add_argument("--kubeadm_token", type=str, required=True) parser.add_argument("--kms_encyr_key_id", type=str, required=False) parser.add_argument("--region_id", type=str, default=None) + parser.add_argument("--label", type=str, default=[], required=False, action='append') + parser.add_argument("--image", type=str, default=[], required=False, action='append') args, unknown = parser.parse_known_args() ins_key_path = args.ins_key @@ -1015,6 +1057,8 @@ def main(): kube_ip = args.kube_ip kubeadm_token = args.kubeadm_token region_id = args.region_id + pre_pull_images = args.image + additional_labels = args.label global zone zone = region_id @@ -1062,9 +1106,9 @@ def main(): api_url = os.environ["API"] api_token = os.environ["API_TOKEN"] ins_id, ins_ip = run_instance(api_url, api_token, resource_name, ins_type, cloud_region, run_id, ins_hdd, ins_img, ins_key_path, - "pipeline", ins_type, is_spot, kube_ip, kubeadm_token) + "pipeline", ins_type, is_spot, kube_ip, kubeadm_token, pre_pull_images) nodename = verify_regnode(ins_id, num_rep, time_rep, api) - label_node(nodename, run_id, api, cluster_name, cluster_role, cloud_region) + label_node(nodename, run_id, api, cluster_name, cluster_role, cloud_region, additional_labels) pipe_log('Node created:\n' '- {}\n' '- {}'.format(ins_id, ins_ip)) diff --git a/scripts/autoscaling/nodeup.py b/scripts/autoscaling/nodeup.py index c2cbcd783e..68c832fb11 100644 --- a/scripts/autoscaling/nodeup.py +++ b/scripts/autoscaling/nodeup.py @@ -38,7 +38,8 @@ def main(): parser.add_argument("--kms_encyr_key_id", type=str, required=False) parser.add_argument("--region_id", type=str, default=None) parser.add_argument("--cloud", type=str, default=None) - + parser.add_argument("--label", type=str, default=[], required=False, action='append') + parser.add_argument("--image", type=str, default=[], required=False, action='append') args = parser.parse_args() ins_key = args.ins_key @@ -57,6 +58,8 @@ def main(): kms_encyr_key_id = args.kms_encyr_key_id region_id = args.region_id cloud = args.cloud + pre_pull_images = args.image + additional_labels = args.label if not kube_ip or not kubeadm_token: raise RuntimeError('Kubernetes configuration is required to create a new node') @@ -97,13 +100,14 @@ def main(): if not ins_id: ins_id, ins_ip = cloud_provider.run_instance(is_spot, bid_price, ins_type, ins_hdd, ins_img, ins_key, run_id, - kms_encyr_key_id, num_rep, time_rep, kube_ip, kubeadm_token) + kms_encyr_key_id, num_rep, time_rep, kube_ip, kubeadm_token, + pre_pull_images) cloud_provider.check_instance(ins_id, run_id, num_rep, time_rep) nodename, nodename_full = cloud_provider.get_instance_names(ins_id) utils.pipe_log('Waiting for instance {} registration in cluster with name {}'.format(ins_id, nodename)) nodename = kube_provider.verify_regnode(ins_id, nodename, nodename_full, num_rep, time_rep) - kube_provider.label_node(nodename, run_id, cluster_name, cluster_role, region_id) + kube_provider.label_node(nodename, run_id, cluster_name, cluster_role, region_id, additional_labels) utils.pipe_log('Node created:\n' '- {}\n' diff --git a/workflows/pipe-common/pipeline/autoscaling/awsprovider.py b/workflows/pipe-common/pipeline/autoscaling/awsprovider.py index 0fe9c1ce92..8e3efe265e 100644 --- a/workflows/pipe-common/pipeline/autoscaling/awsprovider.py +++ b/workflows/pipe-common/pipeline/autoscaling/awsprovider.py @@ -68,14 +68,14 @@ def __init__(self, cloud_region, num_rep=10): self.ec2 = boto3.client('ec2', config=Config(retries={'max_attempts': BOTO3_RETRY_COUNT})) def run_instance(self, is_spot, bid_price, ins_type, ins_hdd, ins_img, ins_key, run_id, kms_encyr_key_id, - num_rep, time_rep, kube_ip, kubeadm_token): + num_rep, time_rep, kube_ip, kubeadm_token, pre_pull_images=[]): ins_id, ins_ip = self.__check_spot_request_exists(num_rep, run_id, time_rep) if ins_id: return ins_id, ins_ip swap_size = utils.get_swap_size(self.cloud_region, ins_type, is_spot, "AWS") user_data_script = utils.get_user_data_script(self.cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, - swap_size) + swap_size, pre_pull_images) if is_spot: ins_id, ins_ip = self.__find_spot_instance(bid_price, run_id, ins_img, ins_type, ins_key, ins_hdd, kms_encyr_key_id, user_data_script, num_rep, time_rep, swap_size) diff --git a/workflows/pipe-common/pipeline/autoscaling/azureprovider.py b/workflows/pipe-common/pipeline/autoscaling/azureprovider.py index 870826c0a3..5b11006ebc 100644 --- a/workflows/pipe-common/pipeline/autoscaling/azureprovider.py +++ b/workflows/pipe-common/pipeline/autoscaling/azureprovider.py @@ -54,12 +54,12 @@ def __init__(self, zone): self.resource_group_name = os.environ["AZURE_RESOURCE_GROUP"] def run_instance(self, is_spot, bid_price, ins_type, ins_hdd, ins_img, ins_key, run_id, kms_encyr_key_id, - num_rep, time_rep, kube_ip, kubeadm_token): + num_rep, time_rep, kube_ip, kubeadm_token, pre_pull_images=[]): try: ins_key = utils.read_ssh_key(ins_key) swap_size = utils.get_swap_size(self.zone, ins_type, is_spot, "AZURE") user_data_script = utils.get_user_data_script(self.zone, ins_type, ins_img, kube_ip, kubeadm_token, - swap_size) + swap_size, pre_pull_images) instance_name = "az-" + uuid.uuid4().hex[0:16] access_config = utils.get_access_config(self.cloud_region) disable_external_access = False diff --git a/workflows/pipe-common/pipeline/autoscaling/gcpprovider.py b/workflows/pipe-common/pipeline/autoscaling/gcpprovider.py index f7c2d1c1a4..71e25bdbc3 100644 --- a/workflows/pipe-common/pipeline/autoscaling/gcpprovider.py +++ b/workflows/pipe-common/pipeline/autoscaling/gcpprovider.py @@ -49,11 +49,11 @@ def __init__(self, cloud_region): self.client = discovery.build('compute', 'v1') def run_instance(self, is_spot, bid_price, ins_type, ins_hdd, ins_img, ins_key, run_id, kms_encyr_key_id, - num_rep, time_rep, kube_ip, kubeadm_token): + num_rep, time_rep, kube_ip, kubeadm_token, pre_pull_images=[]): ssh_pub_key = utils.read_ssh_key(ins_key) swap_size = utils.get_swap_size(self.cloud_region, ins_type, is_spot, "GCP") user_data_script = utils.get_user_data_script(self.cloud_region, ins_type, ins_img, - kube_ip, kubeadm_token, swap_size) + kube_ip, kubeadm_token, swap_size, pre_pull_images) instance_type, gpu_type, gpu_count = self.parse_instance_type(ins_type) machine_type = 'zones/{}/machineTypes/{}'.format(self.cloud_region, instance_type) diff --git a/workflows/pipe-common/pipeline/autoscaling/kubeprovider.py b/workflows/pipe-common/pipeline/autoscaling/kubeprovider.py index 15af008810..e8a6b7e8d7 100644 --- a/workflows/pipe-common/pipeline/autoscaling/kubeprovider.py +++ b/workflows/pipe-common/pipeline/autoscaling/kubeprovider.py @@ -82,7 +82,7 @@ def delete_kubernetes_node_by_name(self, node_name): } pykube.Node(self.api, obj).delete() - def label_node(self, nodename, run_id, cluster_name, cluster_role, cloud_region): + def label_node(self, nodename, run_id, cluster_name, cluster_role, cloud_region, additional_labels): utils.pipe_log('Assigning instance {} to RunID: {}'.format(nodename, run_id)) obj = { "apiVersion": "v1", @@ -96,6 +96,14 @@ def label_node(self, nodename, run_id, cluster_name, cluster_role, cloud_region) } } + if additional_labels: + for label in additional_labels: + label_parts = label.split("=") + if len(label_parts) == 1: + obj["metadata"]["labels"][label_parts[0]] = None + else: + obj["metadata"]["labels"][label_parts[0]] = label_parts[1] + if cluster_name: obj["metadata"]["labels"]["cp-cluster-name"] = cluster_name if cluster_role: diff --git a/workflows/pipe-common/pipeline/autoscaling/utils.py b/workflows/pipe-common/pipeline/autoscaling/utils.py index 50c90e6392..0bff716d92 100644 --- a/workflows/pipe-common/pipeline/autoscaling/utils.py +++ b/workflows/pipe-common/pipeline/autoscaling/utils.py @@ -18,6 +18,7 @@ import json import math from pipeline import Logger, TaskStatus, PipelineAPI, pack_script_contents +import jwt NETWORKS_PARAM = "cluster.networks.config" NODEUP_TASK = "InitializeNode" @@ -29,6 +30,21 @@ script_path = None +def is_run_id_numerical(run_id): + try: + int(run_id) + return True + except ValueError: + return False + + +def is_api_logging_enabled(): + global api_token + global api_url + global current_run_id + return is_run_id_numerical(current_run_id) and api_url and api_token + + def pipe_log_init(run_id): global api_token global api_url @@ -38,7 +54,7 @@ def pipe_log_init(run_id): api_url = os.environ["API"] api_token = os.environ["API_TOKEN"] - if not api_url or not api_token: + if not is_api_logging_enabled(): logging.basicConfig(filename='nodeup.log', level=logging.INFO, format='%(asctime)s %(message)s') @@ -48,7 +64,7 @@ def pipe_log(message, status=TaskStatus.RUNNING): global script_path global current_run_id - if api_url and api_token: + if is_api_logging_enabled(): Logger.log_task_event(NODEUP_TASK, '[{}] {}'.format(current_run_id, message), run_id=current_run_id, @@ -68,7 +84,7 @@ def pipe_log_warn(message): global script_path global current_run_id - if api_url and api_token: + if is_api_logging_enabled(): Logger.warn('[{}] {}'.format(current_run_id, message), task_name=NODEUP_TASK, run_id=current_run_id, @@ -248,7 +264,20 @@ def replace_swap(swap_size, init_script): return init_script -def get_user_data_script(cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, swap_size): +def replace_docker_images(pre_pull_images, user_data_script): + global api_token + payload = jwt.decode(api_token, verify=False) + if 'sub' in payload: + subject = payload['sub'] + user_data_script = user_data_script \ + .replace("@PRE_PULL_DOCKERS@", ",".join(pre_pull_images)) \ + .replace("@API_USER@", subject) + return user_data_script + else: + raise RuntimeError("Pre-pulled docker initialization failed: unable to parse JWT token for docker auth.") + + +def get_user_data_script(cloud_region, ins_type, ins_img, kube_ip, kubeadm_token, swap_size, pre_pull_images=[]): allowed_instance = get_allowed_instance_image(cloud_region, ins_type, ins_img) if allowed_instance and allowed_instance["init_script"]: init_script = open(allowed_instance["init_script"], 'r') @@ -258,6 +287,8 @@ def get_user_data_script(cloud_region, ins_type, ins_img, kube_ip, kubeadm_token init_script.close() user_data_script = replace_proxies(cloud_region, user_data_script) user_data_script = replace_swap(swap_size, user_data_script) + if pre_pull_images: + user_data_script = replace_docker_images(pre_pull_images, user_data_script) user_data_script = user_data_script.replace('@DOCKER_CERTS@', certs_string)\ .replace('@WELL_KNOWN_HOSTS@', well_known_string)\ .replace('@KUBE_IP@', kube_ip)\