Skip to content

Commit

Permalink
Pass in project, domain and version in register rather than serialize (
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Dec 23, 2020
1 parent 99b638f commit 812aca4
Show file tree
Hide file tree
Showing 10 changed files with 412 additions and 119 deletions.
62 changes: 40 additions & 22 deletions flytekit/clis/flyte_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os as _os
import stat as _stat
import sys as _sys
from typing import List, Tuple, Union

import click as _click
import requests as _requests
Expand All @@ -11,11 +12,14 @@
from flyteidl.admin import workflow_pb2 as _workflow_pb2
from flyteidl.core import identifier_pb2 as _identifier_pb2
from flyteidl.core import literals_pb2 as _literals_pb2
from flyteidl.core import tasks_pb2 as _core_tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow_pb2

from flytekit import __version__
from flytekit.clients import friendly as _friendly_client
from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map
from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map
from flytekit.clis.helpers import hydrate_registration_parameters
from flytekit.clis.helpers import parse_args_into_dict as _parse_args_into_dict
from flytekit.common import launch_plan as _launch_plan_common
from flytekit.common import utils as _utils
Expand Down Expand Up @@ -286,6 +290,7 @@ def _render_schedule_expr(lp):
_PROJECT_FLAGS = ["-p", "--project"]
_DOMAIN_FLAGS = ["-d", "--domain"]
_NAME_FLAGS = ["-n", "--name"]
_VERSION_FLAGS = ["-v", "--version"]
_HOST_FLAGS = ["-h", "--host"]
_PRINCIPAL_FLAGS = ["-r", "--principal"]
_INSECURE_FLAGS = ["-i", "--insecure"]
Expand Down Expand Up @@ -1478,27 +1483,37 @@ def register_project(identifier, name, description, host, insecure):
_click.echo("Registered project [id: {}, name: {}, description: {}]".format(identifier, name, description))


def _extract_pair(identifier_file, object_file):
_resource_map = {
_identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec,
_identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec,
_identifier_pb2.TASK: _task_pb2.TaskSpec,
}


def _extract_pair(
identifier_file: str, object_file: str, project: str, domain: str, version: str
) -> Tuple[
_identifier_pb2.Identifier,
Union[_core_tasks_pb2.TaskTemplate, _core_workflow_pb2.WorkflowTemplate, _launch_plan_pb2.LaunchPlanSpec],
]:
"""
:param Text identifier_file:
:param Text object_file:
:rtype: (flyteidl.core.identifier_pb2.Identifier, T)
"""
resource_map = {
_identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec,
_identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec,
_identifier_pb2.TASK: _task_pb2.TaskSpec,
}
id = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file)
if id.resource_type not in resource_map:
identifier = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file)
if identifier.resource_type not in _resource_map:
raise _user_exceptions.FlyteAssertion(
f"Resource type found in identifier {id.resource_type} invalid, must be launch plan, " f"task, or workflow"
f"Resource type found in identifier {identifier.resource_type} invalid, must be launch plan, task, or workflow"
)
entity = _load_proto_from_file(resource_map[id.resource_type], object_file)
return id, entity
entity = _load_proto_from_file(_resource_map[identifier.resource_type], object_file)
registerable_identifier, registerable_entity = hydrate_registration_parameters(
identifier, project, domain, version, entity
)
return registerable_identifier, registerable_entity


def _extract_files(file_paths):
def _extract_files(project: str, domain: str, version: str, file_paths: List[str]):
"""
:param file_paths:
:rtype: List[(flyteidl.core.identifier_pb2.Identifier, T)]
Expand All @@ -1511,31 +1526,34 @@ def _extract_files(file_paths):
filename_iterator = iter(file_paths)
for identifier_file in filename_iterator:
object_file = next(filename_iterator)
id, entity = _extract_pair(identifier_file, object_file)
# Serialized proto files are of the form: 12_foo.bar.<resource_type>.pb
id, entity = _extract_pair(identifier_file, object_file, project, domain, version)
results.append((id, entity))

return results


@_flyte_cli.command("register-files", cls=_FlyteSubCommand)
@_click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to register with.")
@_click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to register with.")
@_click.option(*_VERSION_FLAGS, required=True, help="The entity version to register with")
@_host_option
@_insecure_option
@_click.argument(
"files", type=_click.Path(exists=True), nargs=-1,
)
def register_files(host, insecure, files):
def register_files(project, domain, version, host, insecure, files):
"""
Given a list of files, this will (after sorting the input list), attempt to register them against Flyte Admin.
This command expects the files to be the output of the pyflyte serialize command. See the code there for more
information. Valid files need to be:
information. Valid files need to be:\n
* Ordered in the order that you want registration to happen. pyflyte should have done the topological sort
for you and produced file that have a prefix that sets the correct order.
for you and produced file that have a prefix that sets the correct order.\n
* Of the correct type. That is, they should be the serialized form of one of these Flyte IDL objects
(or an identifier object).
- flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans
- flyteidl.admin.workflow_pb2.WorkflowSpec for workflows
- flyteidl.admin.task_pb2.TaskSpec for tasks
* Each file needs to be accompanied by an identifier file. We can relax this constraint in the future.
(or an identifier object).\n
- flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans\n
- flyteidl.admin.workflow_pb2.WorkflowSpec for workflows\n
- flyteidl.admin.task_pb2.TaskSpec for tasks\n
:param host:
:param insecure:
Expand All @@ -1550,7 +1568,7 @@ def register_files(host, insecure, files):
for f in files:
_click.echo(f" {f}")

flyte_entities_list = _extract_files(files)
flyte_entities_list = _extract_files(project, domain, version, files)
for id, flyte_entity in flyte_entities_list:
try:
if id.resource_type == _identifier_pb2.LAUNCH_PLAN:
Expand Down
79 changes: 79 additions & 0 deletions flytekit/clis/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Tuple

import six as _six
from flyteidl.core import identifier_pb2 as _identifier_pb2
from flyteidl.core import workflow_pb2 as _workflow_pb2
from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType

from flytekit.common.types.helpers import get_sdk_type_from_literal_type as _get_sdk_type_from_literal_type
from flytekit.models import literals as _literals
Expand Down Expand Up @@ -79,3 +84,77 @@ def str2bool(str):
:rtype: bool
"""
return not str.lower() in ["false", "0", "off", "no"]


def _hydrate_identifier(
project: str, domain: str, version: str, identifier: _identifier_pb2.Identifier
) -> _identifier_pb2.Identifier:
identifier.project = identifier.project or project
identifier.domain = identifier.domain or domain
identifier.version = identifier.version or version
return identifier


def _hydrate_workflow_template(
project: str, domain: str, version: str, template: _workflow_pb2.WorkflowTemplate
) -> _workflow_pb2.WorkflowTemplate:
refreshed_nodes = []
for node in template.nodes:
if node.HasField("task_node"):
task_node = node.task_node
task_node.reference_id.CopyFrom(_hydrate_identifier(project, domain, version, task_node.reference_id))
node.task_node.CopyFrom(task_node)
elif node.HasField("workflow_node"):
workflow_node = node.workflow_node
if workflow_node.HasField("launchplan_ref"):
workflow_node.launchplan_ref.CopyFrom(
_hydrate_identifier(project, domain, version, workflow_node.launchplan_ref)
)
elif workflow_node.HasField("sub_workflow_ref"):
workflow_node.sub_workflow_ref.CopyFrom(
_hydrate_identifier(project, domain, version, workflow_node.sub_workflow_ref)
)
node.workflow_node.CopyFrom(workflow_node)
refreshed_nodes.append(node)
# Reassign nodes with the newly hydrated ones.
del template.nodes[:]
template.nodes.extend(refreshed_nodes)
return template


def hydrate_registration_parameters(
identifier: _identifier_pb2.Identifier,
project: str,
domain: str,
version: str,
entity: _GeneratedProtocolMessageType,
) -> Tuple[_identifier_pb2.Identifier, _GeneratedProtocolMessageType]:
"""
This is called at registration time to fill out identifier fields (e.g. project, domain, version) that are mutable.
Entity is one of \b
- flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans\n
- flyteidl.admin.workflow_pb2.WorkflowSpec for workflows\n
- flyteidl.admin.task_pb2.TaskSpec for tasks\n
"""
identifier = _hydrate_identifier(project, domain, version, identifier)

if identifier.resource_type == _identifier_pb2.LAUNCH_PLAN:
entity.workflow_id.CopyFrom(_hydrate_identifier(project, domain, version, entity.workflow_id))
return identifier, entity

entity.template.id.CopyFrom(identifier)
if identifier.resource_type == _identifier_pb2.TASK:
return identifier, entity

# Workflows (the only possible entity type at this point) are a little more complicated.
# Workflow nodes that are defined inline with the workflows will be missing project/domain/version so we fill those
# in now.
# (entity is of type flyteidl.admin.workflow_pb2.WorkflowSpec)
refreshed_sub_workflows = []
for sub_workflow in entity.sub_workflows:
refreshed_sub_workflow = _hydrate_workflow_template(project, domain, version, sub_workflow)
refreshed_sub_workflows.append(refreshed_sub_workflow)
# Reassign subworkflows with the newly hydrated ones.
del entity.sub_workflows[:]
entity.sub_workflows.extend(refreshed_sub_workflows)
return identifier, entity
14 changes: 14 additions & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import click as _click

CTX_PROJECT = "project"
CTX_DOMAIN = "domain"
CTX_VERSION = "version"
CTX_TEST = "test"
CTX_PACKAGES = "pkgs"
CTX_NOTIFICATIONS = "notifications"


project_option = _click.option(
"-p",
"--project",
required=True,
type=str,
help="Flyte project to use. You can have more than one project per repo",
)
domain_option = _click.option(
"-d", "--domain", required=True, type=str, help="This is usually development, staging, or production",
)
15 changes: 13 additions & 2 deletions flytekit/clis/sdk_in_container/launch_plan.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging as _logging
import os as _os

import click
import six as _six

from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map
from flytekit.clis.sdk_in_container import constants as _constants
from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT, CTX_VERSION, domain_option, project_option
from flytekit.common import utils as _utils
from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan
from flytekit.configuration.internal import DOMAIN as _DOMAIN
from flytekit.configuration.internal import IMAGE as _IMAGE
from flytekit.configuration.internal import PROJECT as _PROJECT
from flytekit.configuration.internal import VERSION as _VERSION
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag
from flytekit.models import launch_plan as _launch_plan_model
from flytekit.models.core import identifier as _identifier
Expand Down Expand Up @@ -139,12 +144,18 @@ def _execute_lp(**kwargs):


@click.group("lp")
@project_option
@domain_option
@click.pass_context
def launch_plans(ctx):
def launch_plans(ctx, project, domain):
"""
Launch plan control group, including executions
"""
pass
ctx.obj[CTX_PROJECT] = project
ctx.obj[CTX_DOMAIN] = domain
_os.environ[_PROJECT.env_var] = project
_os.environ[_DOMAIN.env_var] = domain
_os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION]


@click.group("execute", cls=LaunchPlanExecuteGroup)
Expand Down
20 changes: 2 additions & 18 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import click

from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION
from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_VERSION
from flytekit.clis.sdk_in_container.launch_plan import launch_plans
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.serialize import serialize
Expand All @@ -20,16 +20,6 @@


@click.group("pyflyte", invoke_without_command=True)
@click.option(
"-p",
"--project",
required=True,
type=str,
help="Flyte project to use. You can have more than one project per repo",
)
@click.option(
"-d", "--domain", required=True, type=str, help="This is usually development, staging, or production",
)
@click.option(
"-c", "--config", required=False, type=str, help="Path to config file for use within container",
)
Expand All @@ -48,7 +38,7 @@
"-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin",
)
@click.pass_context
def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=None):
def main(ctx, config=None, pkgs=None, version=None, insecure=None):
"""
Entrypoint for all the user commands.
"""
Expand All @@ -60,8 +50,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No
_logging.getLogger().setLevel(log_level)

ctx.obj = dict()
ctx.obj[CTX_PROJECT] = project
ctx.obj[CTX_DOMAIN] = domain
version = version or _look_up_version_from_image_tag(_IMAGE.get())
ctx.obj[CTX_VERSION] = version

Expand All @@ -78,10 +66,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No
pkgs = _WORKFLOW_PACKAGES.get()
ctx.obj[CTX_PACKAGES] = pkgs

_os.environ[_internal_config.PROJECT.env_var] = project
_os.environ[_internal_config.DOMAIN.env_var] = domain
_os.environ[_internal_config.VERSION.env_var] = version


def update_configuration_file(config_file_path):
"""
Expand Down
23 changes: 21 additions & 2 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import logging as _logging
import os as _os

import click

from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_TEST, CTX_VERSION
from flytekit.clis.sdk_in_container.constants import (
CTX_DOMAIN,
CTX_PACKAGES,
CTX_PROJECT,
CTX_TEST,
CTX_VERSION,
domain_option,
project_option,
)
from flytekit.common import utils as _utils
from flytekit.common.core import identifier as _identifier
from flytekit.common.tasks import task as _task
from flytekit.configuration.internal import DOMAIN as _DOMAIN
from flytekit.configuration.internal import IMAGE as _IMAGE
from flytekit.configuration.internal import PROJECT as _PROJECT
from flytekit.configuration.internal import VERSION as _VERSION
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag
from flytekit.tools.module_loader import iterate_registerable_entities_in_order

Expand Down Expand Up @@ -53,13 +65,15 @@ def register_tasks_only(project, domain, pkgs, test, version):


@click.group("register")
@project_option
@domain_option
# --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead
@click.option(
"--pkgs", multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword",
)
@click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin")
@click.pass_context
def register(ctx, pkgs=None, test=None):
def register(ctx, project, domain, pkgs=None, test=None):
"""
Run registration steps for the workflows in this container.
Expand All @@ -69,7 +83,12 @@ def register(ctx, pkgs=None, test=None):
if pkgs:
raise click.UsageError("--pkgs must now be specified before the 'register' keyword on the command line")

ctx.obj[CTX_PROJECT] = project
ctx.obj[CTX_DOMAIN] = domain
ctx.obj[CTX_TEST] = test
_os.environ[_PROJECT.env_var] = project
_os.environ[_DOMAIN.env_var] = domain
_os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION]


@click.command("tasks")
Expand Down
Loading

0 comments on commit 812aca4

Please sign in to comment.