diff --git a/metaflow/plugins/argo/argo_client.py b/metaflow/plugins/argo/argo_client.py
index 6787c631883..78494684882 100644
--- a/metaflow/plugins/argo/argo_client.py
+++ b/metaflow/plugins/argo/argo_client.py
@@ -135,6 +135,60 @@ def delete_workflow_template(self, name):
                 json.loads(e.body)["message"] if e.body is not None else e.reason
             )
 
+    def terminate_workflow(self, run_id):
+        """
+        Request termination of the Argo Workflow named `run_id`.
+
+        The workflow is fetched first so that finished or already-terminated
+        executions are rejected, then its spec is patched with
+        `shutdown: Terminate`.
+
+        Returns the response of the patch call. Raises ArgoClientException if
+        either Kubernetes API call fails, if the workflow has already
+        finished, or if it has already been asked to terminate.
+        """
+        client = self._kubernetes_client.get()
+        try:
+            workflow = client.CustomObjectsApi().get_namespaced_custom_object(
+                group=self._group,
+                version=self._version,
+                namespace=self._namespace,
+                plural="workflows",
+                name=run_id,
+            )
+        except client.rest.ApiException as e:
+            raise ArgoClientException(
+                json.loads(e.body)["message"] if e.body is not None else e.reason
+            )
+
+        # The fetched object is not guaranteed to carry `status` or
+        # `finishedAt`; treat a missing value as "not finished" instead of
+        # failing with a KeyError.
+        status = workflow.get("status") or {}
+        if status.get("finishedAt") is not None:
+            raise ArgoClientException(
+                "Cannot terminate an execution that has already finished."
+            )
+        if workflow["spec"].get("shutdown") == "Terminate":
+            raise ArgoClientException("Execution has already been terminated.")
+
+        # Re-send the fetched spec with only the shutdown strategy changed.
+        body = {"spec": workflow["spec"]}
+        body["spec"]["shutdown"] = "Terminate"
+        try:
+            return client.CustomObjectsApi().patch_namespaced_custom_object(
+                group=self._group,
+                version=self._version,
+                namespace=self._namespace,
+                plural="workflows",
+                name=run_id,
+                body=body,
+            )
+        except client.rest.ApiException as e:
+            raise ArgoClientException(
+                json.loads(e.body)["message"] if e.body is not None else e.reason
+            )
+
     def trigger_workflow_template(self, name, parameters={}):
         client = self._kubernetes_client.get()
         body = {
diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py
index 95d72c04544..31b861ec00c 100644
--- a/metaflow/plugins/argo/argo_workflows.py
+++ b/metaflow/plugins/argo/argo_workflows.py
@@ -165,6 +165,33 @@ def delete(name):
         )
         return True
 
+    @staticmethod
+    def terminate(flow_name, run_id):
+        """
+        Terminate the execution `run_id` of `flow_name` on Argo Workflows.
+
+        `run_id` must carry the `argo-` prefix; the remainder is used as the
+        name of the workflow to terminate. Returns True on success. Raises
+        ArgoWorkflowsException if the prefix is missing or the client returns
+        no response; errors raised by the client itself propagate unchanged.
+        """
+        client = ArgoClient(namespace=KUBERNETES_NAMESPACE)
+        not_found_msg = (
+            "No execution found for {flow_name}/{run_id} in Argo Workflows.".format(
+                flow_name=flow_name, run_id=run_id
+            )
+        )
+
+        # Verify that user is trying to terminate an Argo workflow
+        prefix = "argo-"
+        if not run_id.startswith(prefix):
+            raise ArgoWorkflowsException(not_found_msg)
+        trimmed_run_id = run_id[len(prefix) :]
+
+        response = client.terminate_workflow(trimmed_run_id)
+        if response is None:
+            raise ArgoWorkflowsException(not_found_msg)
+        return True
+
     @classmethod
     def trigger(cls, name, parameters=None):
         if parameters is None:
diff --git a/metaflow/plugins/argo/argo_workflows_cli.py b/metaflow/plugins/argo/argo_workflows_cli.py
index d815d155c97..adb70f5a333 100644
--- a/metaflow/plugins/argo/argo_workflows_cli.py
+++ b/metaflow/plugins/argo/argo_workflows_cli.py
@@ -602,3 +602,27 @@ def echo_token_instructions(obj, name, prev_user, cmd_name, cmd_description=None
         'See "Organizing Results" at docs.metaflow.org for more information '
         "about production tokens."
     )
+
+
+@argo_workflows.command(help="Terminate flow execution on Argo Workflows.")
+@click.option(
+    "--authorize",
+    default=None,
+    type=str,
+    help="Authorize the termination with a production token",
+)
+@click.argument("run-id", required=True, type=str)
+@click.pass_obj
+def terminate(obj, run_id, authorize=None):
+    # Validate the supplied production token (if any) before terminating.
+    validate_token(obj.workflow_name, obj.token_prefix, obj, authorize, "terminate")
+    obj.echo(
+        "Terminating run *{run_id}* for {flow_name} ...".format(
+            run_id=run_id, flow_name=obj.flow.name
+        ),
+        bold=True,
+    )
+
+    # ArgoWorkflows.terminate returns True on success and raises otherwise.
+    if ArgoWorkflows.terminate(obj.flow.name, run_id):
+        obj.echo("\nRun terminated.")