Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass in project, domain and version in register rather than serialize #288

Merged
merged 7 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions flytekit/clis/flyte_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os as _os
import stat as _stat
import sys as _sys
from typing import List, Tuple

import click as _click
import requests as _requests
Expand All @@ -11,11 +12,13 @@
from flyteidl.admin import workflow_pb2 as _workflow_pb2
from flyteidl.core import identifier_pb2 as _identifier_pb2
from flyteidl.core import literals_pb2 as _literals_pb2
from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType

from flytekit import __version__
from flytekit.clients import friendly as _friendly_client
from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map
from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map
from flytekit.clis.helpers import hydrate_registration_parameters
from flytekit.clis.helpers import parse_args_into_dict as _parse_args_into_dict
from flytekit.common import launch_plan as _launch_plan_common
from flytekit.common import utils as _utils
Expand Down Expand Up @@ -286,6 +289,7 @@ def _render_schedule_expr(lp):
_PROJECT_FLAGS = ["-p", "--project"]
_DOMAIN_FLAGS = ["-d", "--domain"]
_NAME_FLAGS = ["-n", "--name"]
_VERSION_FLAGS = ["-v", "--version"]
_HOST_FLAGS = ["-h", "--host"]
_PRINCIPAL_FLAGS = ["-r", "--principal"]
_INSECURE_FLAGS = ["-i", "--insecure"]
Expand Down Expand Up @@ -1478,64 +1482,69 @@ def register_project(identifier, name, description, host, insecure):
_click.echo("Registered project [id: {}, name: {}, description: {}]".format(identifier, name, description))


def _extract_pair(identifier_file, object_file):
"""
:param Text identifier_file:
:param Text object_file:
:rtype: (flyteidl.core.identifier_pb2.Identifier, T)
"""
resource_map = {
_identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec,
_identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec,
_identifier_pb2.TASK: _task_pb2.TaskSpec,
}
id = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file)
if id.resource_type not in resource_map:
_resource_map = {
_identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec,
_identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec,
_identifier_pb2.TASK: _task_pb2.TaskSpec,
}


def _get_entity(
project: str, domain: str, version: str, resource_type_name: str, proto_file: str
) -> Tuple[_identifier_pb2.Identifier, _GeneratedProtocolMessageType]:
resource_type = _identifier_pb2.ResourceType.Value(resource_type_name.upper())
pb2_type = _resource_map.get(resource_type, None)
if not pb2_type:
raise _user_exceptions.FlyteAssertion(
f"Resource type found in identifier {id.resource_type} invalid, must be launch plan, " f"task, or workflow"
f"Resource type found in identifier {resource_type_name} invalid, must be launch plan, task, or workflow"
)
entity = _load_proto_from_file(resource_map[id.resource_type], object_file)
return id, entity

entity = _load_proto_from_file(pb2_type, proto_file)
return hydrate_registration_parameters(resource_type, project, domain, version, entity)


def _extract_files(file_paths):
def _extract_files(project: str, domain: str, version: str, file_paths: List[str]):
"""
:param file_paths:
:rtype: List[(flyteidl.core.identifier_pb2.Identifier, T)]
"""
# Get a manual iterator because we're going to grab files two at a time.
# The identifier file will always come first because the names are always the same and .identifier.pb sorts before
# .pb

results = []
filename_iterator = iter(file_paths)
for identifier_file in filename_iterator:
object_file = next(filename_iterator)
id, entity = _extract_pair(identifier_file, object_file)
results.append((id, entity))
for proto_file in filename_iterator:
# Serialized proto files are of the form: foo.bar.<resource_type>.pb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add the leading 007_ or whatever to the comment too? the index to keep things in order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

file_name_parts = proto_file.split(".")
if file_name_parts[-1] != "pb":
_click.echo(f"Skipping non-proto file {proto_file}")
continue
if len(file_name_parts) <= 2:
raise Exception(f"Serialized proto file {proto_file} has unrecognized file name")
resource_type_name = file_name_parts[-2]
results.append(_get_entity(project, domain, version, resource_type_name, proto_file))

return results


@_flyte_cli.command("register-files", cls=_FlyteSubCommand)
@_click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to register with.")
@_click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to register with.")
@_click.option(*_VERSION_FLAGS, required=True, help="The entity version to register with")
@_host_option
@_insecure_option
@_click.argument(
"files", type=_click.Path(exists=True), nargs=-1,
)
def register_files(host, insecure, files):
def register_files(project, domain, version, host, insecure, files):
"""
Given a list of files, this will (after sorting the input list), attempt to register them against Flyte Admin.
This command expects the files to be the output of the pyflyte serialize command. See the code there for more
information. Valid files need to be:
information. Valid files need to be:\n
* Ordered in the order that you want registration to happen. pyflyte should have done the topological sort
for you and produced file that have a prefix that sets the correct order.
for you and produced file that have a prefix that sets the correct order.\n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we still need the topological sort to guarantee that we register tasks before workflows that reference them, and workflows before launch plans that reference them?

* Of the correct type. That is, they should be the serialized form of one of these Flyte IDL objects
(or an identifier object).
- flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans
- flyteidl.admin.workflow_pb2.WorkflowSpec for workflows
- flyteidl.admin.task_pb2.TaskSpec for tasks
* Each file needs to be accompanied by an identifier file. We can relax this constraint in the future.
(or an identifier object).\n
- flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans\n
- flyteidl.admin.workflow_pb2.WorkflowSpec for workflows\n
- flyteidl.admin.task_pb2.TaskSpec for tasks\n

:param host:
:param insecure:
Expand All @@ -1550,7 +1559,7 @@ def register_files(host, insecure, files):
for f in files:
_click.echo(f" {f}")

flyte_entities_list = _extract_files(files)
flyte_entities_list = _extract_files(project, domain, version, files)
for id, flyte_entity in flyte_entities_list:
try:
if id.resource_type == _identifier_pb2.LAUNCH_PLAN:
Expand Down
83 changes: 83 additions & 0 deletions flytekit/clis/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Tuple

import six as _six
from flyteidl.core import identifier_pb2 as _identifier_pb2
from flyteidl.core import workflow_pb2 as _workflow_pb2
from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType

from flytekit.common.types.helpers import get_sdk_type_from_literal_type as _get_sdk_type_from_literal_type
from flytekit.models import literals as _literals
Expand Down Expand Up @@ -79,3 +84,81 @@ def str2bool(str):
:rtype: bool
"""
return not str.lower() in ["false", "0", "off", "no"]


def _hydrate_identifier(
project: str, domain: str, version: str, identifier: _identifier_pb2.Identifier
) -> _identifier_pb2.Identifier:
identifier.project = identifier.project or project
identifier.domain = identifier.domain or domain
identifier.version = identifier.version or version
return identifier


def _hydrate_workflow_template(
project: str, domain: str, version: str, template: _workflow_pb2.WorkflowTemplate
) -> _workflow_pb2.WorkflowTemplate:
refreshed_nodes = []
for node in template.nodes:
if node.HasField("task_node"):
task_node = node.task_node
task_node.reference_id.CopyFrom(_hydrate_identifier(project, domain, version, task_node.reference_id))
node.task_node.CopyFrom(task_node)
elif node.HasField("workflow_node"):
workflow_node = node.workflow_node
if workflow_node.HasField("launchplan_ref"):
workflow_node.launchplan_ref.CopyFrom(
_hydrate_identifier(project, domain, version, workflow_node.launchplan_ref)
)
elif workflow_node.HasField("sub_workflow_ref"):
workflow_node.sub_workflow_ref.CopyFrom(
_hydrate_identifier(project, domain, version, workflow_node.sub_workflow_ref)
)
node.workflow_node.CopyFrom(workflow_node)
refreshed_nodes.append(node)
# Reassign nodes with the newly hydrated ones.
del template.nodes[:]
template.nodes.extend(refreshed_nodes)
return template


def hydrate_registration_parameters(
resource_type: _identifier_pb2.ResourceType,
project: str,
domain: str,
version: str,
entity: _GeneratedProtocolMessageType,
) -> Tuple[_identifier_pb2.Identifier, _GeneratedProtocolMessageType]:
"""
This is called at registration time to fill out identifier fields (e.g. project, domain, version) that are mutable.
"""
if resource_type == _identifier_pb2.LAUNCH_PLAN:
entity.workflow_id.CopyFrom(_hydrate_identifier(project, domain, version, entity.workflow_id))
return (
_identifier_pb2.Identifier(
resource_type=resource_type,
project=project,
domain=domain,
name=entity.workflow_id.name,
version=version,
),
entity,
)

entity.template.id.CopyFrom(_hydrate_identifier(project, domain, version, entity.template.id))
if resource_type == _identifier_pb2.TASK:
return entity.template.id, entity

# Workflows (the only possible entity type at this point) are a little more complicated.
# Workflow nodes that are defined inline with the workflows will be missing project/domain/version so we fill those
# in now.
# (entity is of type flyteidl.admin.workflow_pb2.WorkflowSpec)
entity.template.CopyFrom(_hydrate_workflow_template(project, domain, version, entity.template))
refreshed_sub_workflows = []
for sub_workflow in entity.sub_workflows:
refreshed_sub_workflow = _hydrate_workflow_template(project, domain, version, sub_workflow)
refreshed_sub_workflows.append(refreshed_sub_workflow)
# Reassign subworkflows with the newly hydrated ones.
del entity.sub_workflows[:]
entity.sub_workflows.extend(refreshed_sub_workflows)
return entity.template.id, entity
14 changes: 14 additions & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import click as _click

CTX_PROJECT = "project"
CTX_DOMAIN = "domain"
CTX_VERSION = "version"
CTX_TEST = "test"
CTX_PACKAGES = "pkgs"
CTX_NOTIFICATIONS = "notifications"


project_option = _click.option(
"-p",
"--project",
required=True,
type=str,
help="Flyte project to use. You can have more than one project per repo",
)
domain_option = _click.option(
"-d", "--domain", required=True, type=str, help="This is usually development, staging, or production",
)
15 changes: 13 additions & 2 deletions flytekit/clis/sdk_in_container/launch_plan.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging as _logging
import os as _os

import click
import six as _six

from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map
from flytekit.clis.sdk_in_container import constants as _constants
from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT, CTX_VERSION, domain_option, project_option
from flytekit.common import utils as _utils
from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan
from flytekit.configuration.internal import DOMAIN as _DOMAIN
from flytekit.configuration.internal import IMAGE as _IMAGE
from flytekit.configuration.internal import PROJECT as _PROJECT
from flytekit.configuration.internal import VERSION as _VERSION
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag
from flytekit.models import launch_plan as _launch_plan_model
from flytekit.models.core import identifier as _identifier
Expand Down Expand Up @@ -139,12 +144,18 @@ def _execute_lp(**kwargs):


@click.group("lp")
@project_option
@domain_option
@click.pass_context
def launch_plans(ctx):
def launch_plans(ctx, project, domain):
"""
Launch plan control group, including executions
"""
pass
ctx.obj[CTX_PROJECT] = project
ctx.obj[CTX_DOMAIN] = domain
_os.environ[_PROJECT.env_var] = project
_os.environ[_DOMAIN.env_var] = domain
_os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION]


@click.group("execute", cls=LaunchPlanExecuteGroup)
Expand Down
20 changes: 2 additions & 18 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import click

from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION
from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_VERSION
from flytekit.clis.sdk_in_container.launch_plan import launch_plans
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.serialize import serialize
Expand All @@ -20,16 +20,6 @@


@click.group("pyflyte", invoke_without_command=True)
@click.option(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this means that if you have pyflyte -p proj -d domain register workflows it'll have to be changed to pyflyte register -p proj -d domain workflows right?

I know I said it was okay to take out project from serialize but I guess I didn't realize what that entailed. if we're going to break compatibility, should we get rid of pyflyte register entirely? Maybe let's discuss with ketan? I'm okay taking it out completely, leaving pyflyte serialize to be the one and only command you need.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kumare3 any thoughts on getting rid of pyflyte register altogether?

@wild-endeavor there's also the pyflyte lp command group too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am all for simplifying this story, but what about existing users and their make files - let's say at lyft

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought backwards compatibility wasn't an explicit concern?

If we want to go that route then we make PROJECT, DOMAIN, VERSION optional params in the pyflyte command group and every subcommand has to check that they're secretly filled in. I see it as simpler to use click features to require the args at the most specific level where they are used

"-p",
"--project",
required=True,
type=str,
help="Flyte project to use. You can have more than one project per repo",
)
@click.option(
"-d", "--domain", required=True, type=str, help="This is usually development, staging, or production",
)
@click.option(
"-c", "--config", required=False, type=str, help="Path to config file for use within container",
)
Expand All @@ -48,7 +38,7 @@
"-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin",
)
@click.pass_context
def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=None):
def main(ctx, config=None, pkgs=None, version=None, insecure=None):
"""
Entrypoint for all the user commands.
"""
Expand All @@ -60,8 +50,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No
_logging.getLogger().setLevel(log_level)

ctx.obj = dict()
ctx.obj[CTX_PROJECT] = project
ctx.obj[CTX_DOMAIN] = domain
version = version or _look_up_version_from_image_tag(_IMAGE.get())
ctx.obj[CTX_VERSION] = version

Expand All @@ -78,10 +66,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No
pkgs = _WORKFLOW_PACKAGES.get()
ctx.obj[CTX_PACKAGES] = pkgs

_os.environ[_internal_config.PROJECT.env_var] = project
_os.environ[_internal_config.DOMAIN.env_var] = domain
_os.environ[_internal_config.VERSION.env_var] = version


def update_configuration_file(config_file_path):
"""
Expand Down
Loading