Skip to content

Commit

Permalink
feature: add validation for what flow the run_id belongs to (#1452)
Browse files Browse the repository at this point in the history
* add validation for what flow the run_id belongs to

* addressing review comments

* test using workflow production token for authorization directly.

* rework run_id validation to perform separate checks for project_name and branch_name, accounting for --name. refactor Argo resource name sanitization

* cleanup
  • Loading branch information
saikonen authored Jun 26, 2023
1 parent b2a2179 commit 63cc392
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 24 deletions.
23 changes: 23 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,29 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, name):
    """
    Look up the Argo workflow called *name* and return its Metaflow annotations.

    Returns None when the client reports no such workflow. Otherwise returns
    the tuple (owner, production_token, flow_name, branch_name, project_name);
    branch_name and project_name are None when their annotations are absent.

    Raises ArgoWorkflowsException when the workflow lacks the metadata,
    annotations, or any of the three mandatory Metaflow annotations.
    """
    client = ArgoClient(namespace=KUBERNETES_NAMESPACE)
    workflow = client.get_workflow(name)
    if workflow is None:
        return None

    try:
        # Kept inside the try so a missing "metadata"/"annotations" key is
        # reported the same way as a missing Metaflow annotation.
        annotations = workflow["metadata"]["annotations"]
        owner = annotations["metaflow/owner"]
        production_token = annotations["metaflow/production_token"]
        flow_name = annotations["metaflow/flow_name"]
        # Branch and project are optional annotations.
        branch_name = annotations.get("metaflow/branch_name", None)
        project_name = annotations.get("metaflow/project_name", None)
    except KeyError:
        raise ArgoWorkflowsException(
            "A non-metaflow workflow *%s* already exists in Argo Workflows."
            % name
        )
    return (owner, production_token, flow_name, branch_name, project_name)

def _process_parameters(self):
parameters = {}
has_schedule = self.flow._flow_decorators.get("schedule") is not None
Expand Down
120 changes: 96 additions & 24 deletions metaflow/plugins/argo/argo_workflows_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


class RunIdMismatch(MetaflowException):
    """
    Raised by validate_run_id when a run_id cannot be used with the current
    flow: it lacks the 'argo-' prefix, or the execution it names belongs to a
    different flow, project, or branch.
    """

    headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
    """
    Exception reported under the headline "Incorrect version for metaflow service".

    NOTE(review): the raise sites are outside this excerpt; presumably used
    when the metadata service version is unsuitable -- confirm against callers.
    """

    headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -333,12 +337,7 @@ def resolve_workflow_name(obj, name):
workflow_name = "%s-%s" % (workflow_name[:242], name_hash)
obj._is_workflow_name_modified = True
if not VALID_NAME.search(workflow_name):
workflow_name = (
re.compile(r"^[^A-Za-z0-9]+")
.sub("", workflow_name)
.replace("_", "")
.lower()
)
workflow_name = sanitize_for_argo(workflow_name)
obj._is_workflow_name_modified = True
else:
if name and not VALID_NAME.search(name):
Expand All @@ -363,12 +362,7 @@ def resolve_workflow_name(obj, name):
raise ArgoWorkflowsNameTooLong(msg)

if not VALID_NAME.search(workflow_name):
workflow_name = (
re.compile(r"^[^A-Za-z0-9]+")
.sub("", workflow_name)
.replace("_", "")
.lower()
)
workflow_name = sanitize_for_argo(workflow_name)
obj._is_workflow_name_modified = True

return workflow_name, token_prefix.lower(), is_project
Expand Down Expand Up @@ -618,19 +612,17 @@ def _token_instructions(flow_name, prev_user):
"Please reach out to them to get the token. Once you have it, call "
"this command:"
)
obj.echo(
" argo-workflows suspend RUN_ID --authorize MY_TOKEN RUN_ID", fg="green"
)
obj.echo(" argo-workflows suspend RUN_ID --authorize MY_TOKEN", fg="green")
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_token(obj.workflow_name, obj.token_prefix, authorize, _token_instructions)
validate_run_id(
obj.workflow_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Verify that user is trying to change an Argo workflow
if not run_id.startswith("argo-"):
raise MetaflowException("Argo workflow execution id's start with 'argo-'")
# Trim prefix from run_id
name = run_id[5:]

workflow_suspended = ArgoWorkflows.suspend(name)
Expand Down Expand Up @@ -662,19 +654,19 @@ def _token_instructions(flow_name, prev_user):
"this command:"
)
obj.echo(
" argo-workflows unsuspend RUN_ID --authorize MY_TOKEN RUN_ID",
" argo-workflows unsuspend RUN_ID --authorize MY_TOKEN",
fg="green",
)
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_token(obj.workflow_name, obj.token_prefix, authorize, _token_instructions)
validate_run_id(
obj.workflow_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Verify that user is trying to change an Argo workflow
if not run_id.startswith("argo-"):
raise MetaflowException("Argo workflow execution id's start with 'argo-'")
# Trim prefix from run_id
name = run_id[5:]

workflow_suspended = ArgoWorkflows.unsuspend(name)
Expand Down Expand Up @@ -721,3 +713,83 @@ def validate_token(name, token_prefix, authorize, instructions_fn=None):

store_token(token_prefix, token)
return True


def validate_run_id(
    workflow_name, token_prefix, authorize, run_id, instructions_fn=None
):
    """
    Check that *run_id* names an Argo Workflows execution of the current flow.

    Checks run in this order: the 'argo-' prefix, existence of the execution
    on Argo Workflows, the flow name, then (only for executions annotated with
    a project) the project and the branch, and finally the production token.

    Returns True when every check passes.

    Raises RunIdMismatch for a bad prefix or a flow/project/branch mismatch,
    MetaflowException when the execution cannot be found, and
    IncorrectProductionToken when the caller is neither the owner nor in
    possession of the matching token.
    """
    # Only executions started through Argo Workflows can be targeted.
    if not run_id.startswith("argo-"):
        raise RunIdMismatch(
            "Run IDs for flows executed through Argo Workflows begin with 'argo-'"
        )

    # Drop the 'argo-' prefix to obtain the Argo resource name.
    execution_name = run_id[5:]
    execution = ArgoWorkflows.get_execution(execution_name)
    if execution is None:
        raise MetaflowException(
            "Could not find workflow *%s* on Argo Workflows" % execution_name
        )

    owner, expected_token, flow_name, branch_name, project_name = execution

    # The flow in the local file must be the one that produced the run.
    # Otherwise --name could be used to act on arbitrary run_ids regardless
    # of which flow file the command is invoked from.
    if current.flow_name != flow_name:
        raise RunIdMismatch(
            "The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*."
            % (run_id, flow_name, current.flow_name)
        )

    if project_name is not None:
        # Project check. The trailing separator keeps a prefix such as
        # 'test_proj' from matching inside 'test_project'.
        project_part = "%s." % sanitize_for_argo(project_name)
        project_ok = (
            current.get("project_name") == project_name
            or project_part in workflow_name
        )
        if not project_ok:
            raise RunIdMismatch(
                "The workflow belongs to the project *%s*. "
                "Please use the project decorator or --name to target the correct project"
                % project_name
            )

        # Branch check. Separators on both sides keep 'user.tes' from
        # matching inside 'user.test'.
        branch_part = ".%s." % sanitize_for_argo(branch_name)
        branch_ok = (
            current.get("branch_name") == branch_name
            or branch_part in workflow_name
        )
        if not branch_ok:
            raise RunIdMismatch(
                "The workflow belongs to the branch *%s*. "
                "Please use --branch, --production or --name to target the correct branch"
                % branch_name
            )

    # Production token check. The token is deliberately not cached here:
    # run_id-based operations may target runs that were not authored from
    # the local environment.
    if authorize is None:
        supplied_token = load_token(token_prefix)
    elif authorize.startswith("production:"):
        supplied_token = authorize[11:]
    else:
        supplied_token = authorize

    # The owner of the execution is always allowed; anyone else needs the token.
    is_authorized = owner == get_username() or supplied_token == expected_token
    if not is_authorized:
        if instructions_fn:
            instructions_fn(flow_name=execution_name, prev_user=owner)
        raise IncorrectProductionToken("Try again with the correct production token.")

    return True


def sanitize_for_argo(text):
    """
    Return *text* reduced to a form usable in Argo Workflow resource names.

    Any leading run of characters other than ASCII letters and digits is
    dropped, every underscore is removed, and the result is lowercased.
    """
    without_leading_junk = re.sub(r"^[^A-Za-z0-9]+", "", text)
    without_underscores = without_leading_junk.replace("_", "")
    return without_underscores.lower()

0 comments on commit 63cc392

Please sign in to comment.