Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add validation for what flow the run_id belongs to #1452

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,29 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, name):
workflow = ArgoClient(namespace=KUBERNETES_NAMESPACE).get_workflow(name)
if workflow is not None:
try:
return (
workflow["metadata"]["annotations"]["metaflow/owner"],
workflow["metadata"]["annotations"]["metaflow/production_token"],
workflow["metadata"]["annotations"]["metaflow/flow_name"],
workflow["metadata"]["annotations"].get(
"metaflow/branch_name", None
),
workflow["metadata"]["annotations"].get(
"metaflow/project_name", None
),
)
except KeyError:
raise ArgoWorkflowsException(
"A non-metaflow workflow *%s* already exists in Argo Workflows."
% name
)
return None

def _process_parameters(self):
parameters = {}
has_schedule = self.flow._flow_decorators.get("schedule") is not None
Expand Down
120 changes: 96 additions & 24 deletions metaflow/plugins/argo/argo_workflows_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


class RunIdMismatch(MetaflowException):
headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -333,12 +337,7 @@ def resolve_workflow_name(obj, name):
workflow_name = "%s-%s" % (workflow_name[:242], name_hash)
obj._is_workflow_name_modified = True
if not VALID_NAME.search(workflow_name):
workflow_name = (
re.compile(r"^[^A-Za-z0-9]+")
.sub("", workflow_name)
.replace("_", "")
.lower()
)
workflow_name = sanitize_for_argo(workflow_name)
obj._is_workflow_name_modified = True
else:
if name and not VALID_NAME.search(name):
Expand All @@ -363,12 +362,7 @@ def resolve_workflow_name(obj, name):
raise ArgoWorkflowsNameTooLong(msg)

if not VALID_NAME.search(workflow_name):
workflow_name = (
re.compile(r"^[^A-Za-z0-9]+")
.sub("", workflow_name)
.replace("_", "")
.lower()
)
workflow_name = sanitize_for_argo(workflow_name)
obj._is_workflow_name_modified = True

return workflow_name, token_prefix.lower(), is_project
Expand Down Expand Up @@ -618,19 +612,17 @@ def _token_instructions(flow_name, prev_user):
"Please reach out to them to get the token. Once you have it, call "
"this command:"
)
obj.echo(
" argo-workflows suspend RUN_ID --authorize MY_TOKEN RUN_ID", fg="green"
)
obj.echo(" argo-workflows suspend RUN_ID --authorize MY_TOKEN", fg="green")
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_token(obj.workflow_name, obj.token_prefix, authorize, _token_instructions)
validate_run_id(
obj.workflow_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Verify that user is trying to change an Argo workflow
if not run_id.startswith("argo-"):
raise MetaflowException("Argo workflow execution id's start with 'argo-'")
# Trim prefix from run_id
name = run_id[5:]

workflow_suspended = ArgoWorkflows.suspend(name)
Expand Down Expand Up @@ -662,19 +654,19 @@ def _token_instructions(flow_name, prev_user):
"this command:"
)
obj.echo(
" argo-workflows unsuspend RUN_ID --authorize MY_TOKEN RUN_ID",
" argo-workflows unsuspend RUN_ID --authorize MY_TOKEN",
fg="green",
)
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_token(obj.workflow_name, obj.token_prefix, authorize, _token_instructions)
validate_run_id(
obj.workflow_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Verify that user is trying to change an Argo workflow
if not run_id.startswith("argo-"):
raise MetaflowException("Argo workflow execution id's start with 'argo-'")
# Trim prefix from run_id
name = run_id[5:]

workflow_suspended = ArgoWorkflows.unsuspend(name)
Expand Down Expand Up @@ -721,3 +713,83 @@ def validate_token(name, token_prefix, authorize, instructions_fn=None):

store_token(token_prefix, token)
return True


def validate_run_id(
workflow_name, token_prefix, authorize, run_id, instructions_fn=None
):
"""
Validates that a run_id adheres to the Argo Workflows naming rules, and
that it belongs to the current flow (accounting for project branch as well).
"""
# Verify that user is trying to change an Argo workflow
if not run_id.startswith("argo-"):
raise RunIdMismatch(
"Run IDs for flows executed through Argo Workflows begin with 'argo-'"
)

# Verify that run_id belongs to the Flow, and that branches match
name = run_id[5:]
workflow = ArgoWorkflows.get_execution(name)
if workflow is None:
raise MetaflowException("Could not find workflow *%s* on Argo Workflows" % name)

owner, token, flow_name, branch_name, project_name = workflow

# Verify we are operating on the correct Flow file compared to the running one.
# Without this check, using --name could be used to run commands for arbitrary run_id's, disregarding the Flow in the file.
if current.flow_name != flow_name:
raise RunIdMismatch(
"The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*."
% (run_id, flow_name, current.flow_name)
)

if project_name is not None:
# Verify we are operating on the correct project.
# Perform match with separators to avoid substrings matching
# e.g. 'test_proj' and 'test_project' should count as a mismatch.
project_part = "%s." % sanitize_for_argo(project_name)
if (
current.get("project_name") != project_name
and project_part not in workflow_name
):
raise RunIdMismatch(
"The workflow belongs to the project *%s*. "
"Please use the project decorator or --name to target the correct project"
% project_name
)

# Verify we are operating on the correct branch.
# Perform match with separators to avoid substrings matching.
# e.g. 'user.tes' and 'user.test' should count as a mismatch.
branch_part = ".%s." % sanitize_for_argo(branch_name)
if (
current.get("branch_name") != branch_name
and branch_part not in workflow_name
):
raise RunIdMismatch(
"The workflow belongs to the branch *%s*. "
"Please use --branch, --production or --name to target the correct branch"
% branch_name
)

# Verify that the production tokens match. We do not want to cache the token that was used though,
# as the operations that require run_id validation can target runs not authored from the local environment
if authorize is None:
authorize = load_token(token_prefix)
elif authorize.startswith("production:"):
authorize = authorize[11:]

if owner != get_username() and authorize != token:
if instructions_fn:
instructions_fn(flow_name=name, prev_user=owner)
raise IncorrectProductionToken("Try again with the correct production token.")

return True


def sanitize_for_argo(text):
"""
Sanitizes a string so it does not contain characters that are not permitted in Argo Workflow resource names.
"""
return re.compile(r"^[^A-Za-z0-9]+").sub("", text).replace("_", "").lower()