diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 49f4a9cf94..a1fc31b173 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -2,6 +2,7 @@ import os as _os import stat as _stat import sys as _sys +from typing import List, Tuple, Union import click as _click import requests as _requests @@ -11,11 +12,14 @@ from flyteidl.admin import workflow_pb2 as _workflow_pb2 from flyteidl.core import identifier_pb2 as _identifier_pb2 from flyteidl.core import literals_pb2 as _literals_pb2 +from flyteidl.core import tasks_pb2 as _core_tasks_pb2 +from flyteidl.core import workflow_pb2 as _core_workflow_pb2 from flytekit import __version__ from flytekit.clients import friendly as _friendly_client from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map +from flytekit.clis.helpers import hydrate_registration_parameters from flytekit.clis.helpers import parse_args_into_dict as _parse_args_into_dict from flytekit.common import launch_plan as _launch_plan_common from flytekit.common import utils as _utils @@ -286,6 +290,7 @@ def _render_schedule_expr(lp): _PROJECT_FLAGS = ["-p", "--project"] _DOMAIN_FLAGS = ["-d", "--domain"] _NAME_FLAGS = ["-n", "--name"] +_VERSION_FLAGS = ["-v", "--version"] _HOST_FLAGS = ["-h", "--host"] _PRINCIPAL_FLAGS = ["-r", "--principal"] _INSECURE_FLAGS = ["-i", "--insecure"] @@ -1478,27 +1483,37 @@ def register_project(identifier, name, description, host, insecure): _click.echo("Registered project [id: {}, name: {}, description: {}]".format(identifier, name, description)) -def _extract_pair(identifier_file, object_file): +_resource_map = { + _identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec, + _identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec, + _identifier_pb2.TASK: _task_pb2.TaskSpec, +} + + +def _extract_pair( + identifier_file: str, object_file: str, project: str, domain: str, version: str +) -> Tuple[ + _identifier_pb2.Identifier, + Union[_core_tasks_pb2.TaskTemplate, _core_workflow_pb2.WorkflowTemplate, _launch_plan_pb2.LaunchPlanSpec], +]: """ :param Text identifier_file: :param Text object_file: :rtype: (flyteidl.core.identifier_pb2.Identifier, T) """ - resource_map = { - _identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec, - _identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec, - _identifier_pb2.TASK: _task_pb2.TaskSpec, - } - id = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file) - if id.resource_type not in resource_map: + identifier = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file) + if identifier.resource_type not in _resource_map: raise _user_exceptions.FlyteAssertion( - f"Resource type found in identifier {id.resource_type} invalid, must be launch plan, " f"task, or workflow" + f"Resource type found in identifier {identifier.resource_type} invalid, must be launch plan, task, or workflow" ) - entity = _load_proto_from_file(resource_map[id.resource_type], object_file) - return id, entity + entity = _load_proto_from_file(_resource_map[identifier.resource_type], object_file) + registerable_identifier, registerable_entity = hydrate_registration_parameters( + identifier, project, domain, version, entity + ) + return registerable_identifier, registerable_entity -def _extract_files(file_paths): +def _extract_files(project: str, domain: str, version: str, file_paths: List[str]): """ :param file_paths: :rtype: List[(flyteidl.core.identifier_pb2.Identifier, T)] @@ -1511,31 +1526,34 @@ def _extract_files(file_paths): filename_iterator = iter(file_paths) for identifier_file in filename_iterator: object_file = next(filename_iterator) - id, entity = _extract_pair(identifier_file, object_file) + # Serialized proto files are of the form: 12_foo.bar..pb + id, entity = _extract_pair(identifier_file, object_file, project, domain, version) results.append((id, entity)) return results @_flyte_cli.command("register-files", cls=_FlyteSubCommand) +@_click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to register with.") +@_click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to register with.") +@_click.option(*_VERSION_FLAGS, required=True, help="The entity version to register with") @_host_option @_insecure_option @_click.argument( "files", type=_click.Path(exists=True), nargs=-1, ) -def register_files(host, insecure, files): +def register_files(project, domain, version, host, insecure, files): """ Given a list of files, this will (after sorting the input list), attempt to register them against Flyte Admin. This command expects the files to be the output of the pyflyte serialize command. See the code there for more - information. Valid files need to be: + information. Valid files need to be:\n * Ordered in the order that you want registration to happen. pyflyte should have done the topological sort - for you and produced file that have a prefix that sets the correct order. + for you and produced file that have a prefix that sets the correct order.\n * Of the correct type. That is, they should be the serialized form of one of these Flyte IDL objects - (or an identifier object). - - flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans - - flyteidl.admin.workflow_pb2.WorkflowSpec for workflows - - flyteidl.admin.task_pb2.TaskSpec for tasks - * Each file needs to be accompanied by an identifier file. We can relax this constraint in the future. + (or an identifier object).\n + - flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans\n + - flyteidl.admin.workflow_pb2.WorkflowSpec for workflows\n + - flyteidl.admin.task_pb2.TaskSpec for tasks\n :param host: :param insecure: @@ -1550,7 +1568,7 @@ def register_files(host, insecure, files): for f in files: _click.echo(f" {f}") - flyte_entities_list = _extract_files(files) + flyte_entities_list = _extract_files(project, domain, version, files) for id, flyte_entity in flyte_entities_list: try: if id.resource_type == _identifier_pb2.LAUNCH_PLAN: diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index c38baa638a..11573bc7a6 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -1,4 +1,9 @@ +from typing import Tuple + import six as _six +from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flyteidl.core import workflow_pb2 as _workflow_pb2 +from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType from flytekit.common.types.helpers import get_sdk_type_from_literal_type as _get_sdk_type_from_literal_type from flytekit.models import literals as _literals @@ -79,3 +84,77 @@ def str2bool(str): :rtype: bool """ return not str.lower() in ["false", "0", "off", "no"] + + +def _hydrate_identifier( + project: str, domain: str, version: str, identifier: _identifier_pb2.Identifier +) -> _identifier_pb2.Identifier: + identifier.project = identifier.project or project + identifier.domain = identifier.domain or domain + identifier.version = identifier.version or version + return identifier + + +def _hydrate_workflow_template( + project: str, domain: str, version: str, template: _workflow_pb2.WorkflowTemplate +) -> _workflow_pb2.WorkflowTemplate: + refreshed_nodes = [] + for node in template.nodes: + if node.HasField("task_node"): + task_node = node.task_node + task_node.reference_id.CopyFrom(_hydrate_identifier(project, domain, version, task_node.reference_id)) + node.task_node.CopyFrom(task_node) + elif node.HasField("workflow_node"): + workflow_node = node.workflow_node + if workflow_node.HasField("launchplan_ref"): + workflow_node.launchplan_ref.CopyFrom( + _hydrate_identifier(project, domain, version, workflow_node.launchplan_ref) + ) + elif workflow_node.HasField("sub_workflow_ref"): + workflow_node.sub_workflow_ref.CopyFrom( + _hydrate_identifier(project, domain, version, workflow_node.sub_workflow_ref) + ) + node.workflow_node.CopyFrom(workflow_node) + refreshed_nodes.append(node) + # Reassign nodes with the newly hydrated ones. + del template.nodes[:] + template.nodes.extend(refreshed_nodes) + return template + + +def hydrate_registration_parameters( + identifier: _identifier_pb2.Identifier, + project: str, + domain: str, + version: str, + entity: _GeneratedProtocolMessageType, +) -> Tuple[_identifier_pb2.Identifier, _GeneratedProtocolMessageType]: + """ + This is called at registration time to fill out identifier fields (e.g. project, domain, version) that are mutable. + Entity is one of \b + - flyteidl.admin.launch_plan_pb2.LaunchPlanSpec for launch plans\n + - flyteidl.admin.workflow_pb2.WorkflowSpec for workflows\n + - flyteidl.admin.task_pb2.TaskSpec for tasks\n + """ + identifier = _hydrate_identifier(project, domain, version, identifier) + + if identifier.resource_type == _identifier_pb2.LAUNCH_PLAN: + entity.workflow_id.CopyFrom(_hydrate_identifier(project, domain, version, entity.workflow_id)) + return identifier, entity + + entity.template.id.CopyFrom(identifier) + if identifier.resource_type == _identifier_pb2.TASK: + return identifier, entity + + # Workflows (the only possible entity type at this point) are a little more complicated. + # Workflow nodes that are defined inline with the workflows will be missing project/domain/version so we fill those + # in now. + # (entity is of type flyteidl.admin.workflow_pb2.WorkflowSpec) + refreshed_sub_workflows = [] + for sub_workflow in entity.sub_workflows: + refreshed_sub_workflow = _hydrate_workflow_template(project, domain, version, sub_workflow) + refreshed_sub_workflows.append(refreshed_sub_workflow) + # Reassign subworkflows with the newly hydrated ones. + del entity.sub_workflows[:] + entity.sub_workflows.extend(refreshed_sub_workflows) + return identifier, entity diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 34dc4460bd..dec90bda5e 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -1,6 +1,20 @@ +import click as _click + CTX_PROJECT = "project" CTX_DOMAIN = "domain" CTX_VERSION = "version" CTX_TEST = "test" CTX_PACKAGES = "pkgs" CTX_NOTIFICATIONS = "notifications" + + +project_option = _click.option( + "-p", + "--project", + required=True, + type=str, + help="Flyte project to use. You can have more than one project per repo", +) +domain_option = _click.option( + "-d", "--domain", required=True, type=str, help="This is usually development, staging, or production", +) diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index b8d1f18cef..3f93f3e99b 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -1,13 +1,18 @@ import logging as _logging +import os as _os import click import six as _six from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map from flytekit.clis.sdk_in_container import constants as _constants +from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT, CTX_VERSION, domain_option, project_option from flytekit.common import utils as _utils from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan +from flytekit.configuration.internal import DOMAIN as _DOMAIN from flytekit.configuration.internal import IMAGE as _IMAGE +from flytekit.configuration.internal import PROJECT as _PROJECT +from flytekit.configuration.internal import VERSION as _VERSION from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag from flytekit.models import launch_plan as _launch_plan_model from flytekit.models.core import identifier as _identifier @@ -139,12 +144,18 @@ def _execute_lp(**kwargs): @click.group("lp") +@project_option +@domain_option @click.pass_context -def launch_plans(ctx): +def launch_plans(ctx, project, domain): """ Launch plan control group, including executions """ - pass + ctx.obj[CTX_PROJECT] = project + ctx.obj[CTX_DOMAIN] = domain + _os.environ[_PROJECT.env_var] = project + _os.environ[_DOMAIN.env_var] = domain + _os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION] @click.group("execute", cls=LaunchPlanExecuteGroup) diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 93da6fd54e..ea69537bd3 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -4,7 +4,7 @@ import click -from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION +from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_VERSION from flytekit.clis.sdk_in_container.launch_plan import launch_plans from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.serialize import serialize @@ -20,16 +20,6 @@ @click.group("pyflyte", invoke_without_command=True) -@click.option( - "-p", - "--project", - required=True, - type=str, - help="Flyte project to use. You can have more than one project per repo", -) -@click.option( - "-d", "--domain", required=True, type=str, help="This is usually development, staging, or production", -) @click.option( "-c", "--config", required=False, type=str, help="Path to config file for use within container", ) @@ -48,7 +38,7 @@ "-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin", ) @click.pass_context -def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=None): +def main(ctx, config=None, pkgs=None, version=None, insecure=None): """ Entrypoint for all the user commands. """ @@ -60,8 +50,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No _logging.getLogger().setLevel(log_level) ctx.obj = dict() - ctx.obj[CTX_PROJECT] = project - ctx.obj[CTX_DOMAIN] = domain version = version or _look_up_version_from_image_tag(_IMAGE.get()) ctx.obj[CTX_VERSION] = version @@ -78,10 +66,6 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No pkgs = _WORKFLOW_PACKAGES.get() ctx.obj[CTX_PACKAGES] = pkgs - _os.environ[_internal_config.PROJECT.env_var] = project - _os.environ[_internal_config.DOMAIN.env_var] = domain - _os.environ[_internal_config.VERSION.env_var] = version - def update_configuration_file(config_file_path): """ diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 45e4ebd1fb..aa6c606477 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -1,12 +1,24 @@ import logging as _logging +import os as _os import click -from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_TEST, CTX_VERSION +from flytekit.clis.sdk_in_container.constants import ( + CTX_DOMAIN, + CTX_PACKAGES, + CTX_PROJECT, + CTX_TEST, + CTX_VERSION, + domain_option, + project_option, +) from flytekit.common import utils as _utils from flytekit.common.core import identifier as _identifier from flytekit.common.tasks import task as _task +from flytekit.configuration.internal import DOMAIN as _DOMAIN from flytekit.configuration.internal import IMAGE as _IMAGE +from flytekit.configuration.internal import PROJECT as _PROJECT +from flytekit.configuration.internal import VERSION as _VERSION from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag from flytekit.tools.module_loader import iterate_registerable_entities_in_order @@ -53,13 +65,15 @@ def register_tasks_only(project, domain, pkgs, test, version): @click.group("register") +@project_option +@domain_option # --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead @click.option( "--pkgs", multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword", ) @click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin") @click.pass_context -def register(ctx, pkgs=None, test=None): +def register(ctx, project, domain, pkgs=None, test=None): """ Run registration steps for the workflows in this container. @@ -69,7 +83,12 @@ def register(ctx, pkgs=None, test=None): if pkgs: raise click.UsageError("--pkgs must now be specified before the 'register' keyword on the command line") + ctx.obj[CTX_PROJECT] = project + ctx.obj[CTX_DOMAIN] = domain ctx.obj[CTX_TEST] = test + _os.environ[_PROJECT.env_var] = project + _os.environ[_DOMAIN.env_var] = domain + _os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION] @click.command("tasks") diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index 95a51340cb..08f323c815 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -8,25 +8,25 @@ from flytekit.annotated.base_task import PythonTask from flytekit.annotated.launch_plan import LaunchPlan from flytekit.annotated.workflow import Workflow -from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION +from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES from flytekit.common import utils as _utils from flytekit.common.core import identifier as _identifier from flytekit.common.exceptions.scopes import system_entry_point from flytekit.common.tasks import task as _sdk_task from flytekit.common.utils import write_proto_to_file as _write_proto_to_file -from flytekit.configuration import TemporaryConfiguration from flytekit.configuration import auth as _auth_config from flytekit.configuration import internal as _internal_config from flytekit.tools.module_loader import iterate_registerable_entities_in_order +_PROJECT_PLACEHOLDER = "" +_DOMAIN_PLACEHOLDER = "" +_VERSION_PLACEHOLDER = "" + @system_entry_point -def serialize_tasks_only(project, domain, pkgs, version, folder=None): +def serialize_tasks_only(pkgs, folder=None): """ - :param Text project: - :param Text domain: :param list[Text] pkgs: - :param Text version: :param Text folder: :return: @@ -38,7 +38,9 @@ def serialize_tasks_only(project, domain, pkgs, version, folder=None): for m, k, o in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}): name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type) _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in)) - o._id = _identifier.Identifier(o.resource_type, project, domain, name, version) + o._id = _identifier.Identifier( + o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER + ) loaded_entities.append(o) zero_padded_length = _determine_text_chars(len(loaded_entities)) @@ -58,7 +60,7 @@ def serialize_tasks_only(project, domain, pkgs, version, folder=None): @system_entry_point -def serialize_all(project, domain, pkgs, version, folder=None): +def serialize_all(pkgs, folder=None): """ In order to register, we have to comply with Admin's endpoints. Those endpoints take the following objects. These flyteidl.admin.launch_plan_pb2.LaunchPlanSpec @@ -72,10 +74,7 @@ def serialize_all(project, domain, pkgs, version, folder=None): For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects. - :param Text project: - :param Text domain: :param list[Text] pkgs: - :param Text version: :param Text folder: :return: @@ -90,9 +89,9 @@ def serialize_all(project, domain, pkgs, version, folder=None): } registration_settings = flyte_context.RegistrationSettings( - project=project, - domain=domain, - version=version, + project=_PROJECT_PLACEHOLDER, + domain=_DOMAIN_PLACEHOLDER, + version=_VERSION_PLACEHOLDER, image_config=flyte_context.get_image_config(), env=env, iam_role=_auth_config.ASSUMABLE_IAM_ROLE.get(), @@ -106,7 +105,9 @@ def serialize_all(project, domain, pkgs, version, folder=None): for m, k, o in iterate_registerable_entities_in_order(pkgs): name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type) _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in)) - o._id = _identifier.Identifier(o.resource_type, project, domain, name, version) + o._id = _identifier.Identifier( + o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER + ) loaded_entities.append(o) click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows") @@ -138,7 +139,7 @@ def serialize_all(project, domain, pkgs, version, folder=None): serialized = entity.serialize() fname_index = str(i).zfill(zero_padded_length) fname = "{}_{}.pb".format(fname_index, entity.id.name) - click.echo(f" Writing type: {entity.id.resource_type}, {entity.id.name} to\n {fname}") + click.echo(f" Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n {fname}") if folder: fname = _os.path.join(folder, fname) _write_proto_to_file(serialized, fname) @@ -181,85 +182,30 @@ def serialize(ctx): @click.command("tasks") -@click.option( - "-v", - "--version", - type=str, - help="Version to serialize tasks with. This is normally parsed from the" "image, but you can override here.", -) @click.option("-f", "--folder", type=click.Path(exists=True)) @click.pass_context -def tasks(ctx, version=None, folder=None): - project = ctx.obj[CTX_PROJECT] - domain = ctx.obj[CTX_DOMAIN] +def tasks(ctx, folder=None): pkgs = ctx.obj[CTX_PACKAGES] if folder: click.echo(f"Writing output to {folder}") - version = ( - version or ctx.obj[CTX_VERSION] or _internal_config.look_up_version_from_image_tag(_internal_config.IMAGE.get()) - ) - - internal_settings = { - "project": project, - "domain": domain, - "version": version, - } - # Populate internal settings for project/domain/version from the environment so that the file names are resolved - # with the correct strings. The file itself doesn't need to change though. - with TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get(), internal_settings): - _logging.debug( - "Serializing with settings\n" - "\n Project: {}" - "\n Domain: {}" - "\n Version: {}" - "\n\nover the following packages {}".format(project, domain, version, pkgs) - ) - serialize_tasks_only(project, domain, pkgs, version, folder) + serialize_tasks_only(pkgs, folder) @click.command("workflows") -@click.option( - "-v", - "--version", - type=str, - help="Version to serialize tasks with. This is normally parsed from the" "image, but you can override here.", -) # For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the # directory for you so it shouldn't be a problem. @click.option("-f", "--folder", type=click.Path(exists=True)) @click.pass_context -def workflows(ctx, version=None, folder=None): +def workflows(ctx, folder=None): _logging.getLogger().setLevel(_logging.DEBUG) if folder: click.echo(f"Writing output to {folder}") - project = ctx.obj[CTX_PROJECT] - domain = ctx.obj[CTX_DOMAIN] pkgs = ctx.obj[CTX_PACKAGES] - - version = ( - version or ctx.obj[CTX_VERSION] or _internal_config.look_up_version_from_image_tag(_internal_config.IMAGE.get()) - ) - - internal_settings = { - "project": project, - "domain": domain, - "version": version, - } - # Populate internal settings for project/domain/version from the environment so that the file names are resolved - # with the correct strings. The file itself doesn't need to change though. - with TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get(), internal_settings): - _logging.debug( - "Serializing with settings\n" - "\n Project: {}" - "\n Domain: {}" - "\n Version: {}" - "\n\nover the following packages {}".format(project, domain, version, pkgs) - ) - serialize_all(project, domain, pkgs, version, folder) + serialize_all(pkgs, folder) serialize.add_command(tasks) diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index 17e27644e6..0fa6a87594 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -33,6 +33,9 @@ def resource_type(self): """ return self._resource_type + def resource_type_name(self) -> str: + return _identifier_pb2.ResourceType.Name(self.resource_type) + @property def project(self): """ diff --git a/tests/flytekit/unit/cli/test_helpers.py b/tests/flytekit/unit/cli/test_helpers.py index 10022778e5..a4c819c4eb 100644 --- a/tests/flytekit/unit/cli/test_helpers.py +++ b/tests/flytekit/unit/cli/test_helpers.py @@ -1,6 +1,13 @@ +import flyteidl.admin.launch_plan_pb2 as _launch_plan_pb2 +import flyteidl.admin.task_pb2 as _task_pb2 +import flyteidl.admin.workflow_pb2 as _workflow_pb2 +import flyteidl.core.tasks_pb2 as _core_task_pb2 import pytest +from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flyteidl.core import workflow_pb2 as _core_workflow_pb2 from flytekit.clis import helpers +from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template, hydrate_registration_parameters from flytekit.models import literals, types from flytekit.models.interface import Parameter, ParameterMap, Variable @@ -58,3 +65,212 @@ def test_strtobool(): assert helpers.str2bool("t") assert helpers.str2bool("true") assert helpers.str2bool("stuff") + + +def test_hydrate_identifier(): + identifier = _hydrate_identifier("project", "domain", "12345", _identifier_pb2.Identifier()) + assert identifier.project == "project" + assert identifier.domain == "domain" + assert identifier.version == "12345" + + identifier = _hydrate_identifier( + "project2", "domain2", "abc", _identifier_pb2.Identifier(project="project", domain="domain", version="12345") + ) + assert identifier.project == "project" + assert identifier.domain == "domain" + assert identifier.version == "12345" + + +def test_hydrate_workflow_template(): + workflow_template = _core_workflow_pb2.WorkflowTemplate() + workflow_template.nodes.append( + _core_workflow_pb2.Node( + id="task_node", + task_node=_core_workflow_pb2.TaskNode( + reference_id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.TASK) + ), + ) + ) + workflow_template.nodes.append( + _core_workflow_pb2.Node( + id="launchplan_ref", + workflow_node=_core_workflow_pb2.WorkflowNode( + launchplan_ref=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, project="project2", + ) + ), + ) + ) + workflow_template.nodes.append( + _core_workflow_pb2.Node( + id="sub_workflow_ref", + workflow_node=_core_workflow_pb2.WorkflowNode( + sub_workflow_ref=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", + ) + ), + ) + ) + workflow_template.nodes.append( + _core_workflow_pb2.Node( + id="unchanged", + task_node=_core_workflow_pb2.TaskNode( + reference_id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", version="abc" + ) + ), + ) + ) + hydrated_workflow_template = _hydrate_workflow_template("project", "domain", "12345", workflow_template) + assert len(hydrated_workflow_template.nodes) == 4 + task_node_identifier = hydrated_workflow_template.nodes[0].task_node.reference_id + assert task_node_identifier.project == "project" + assert task_node_identifier.domain == "domain" + assert task_node_identifier.version == "12345" + + launchplan_ref_identifier = hydrated_workflow_template.nodes[1].workflow_node.launchplan_ref + assert launchplan_ref_identifier.project == "project2" + assert launchplan_ref_identifier.domain == "domain" + assert launchplan_ref_identifier.version == "12345" + + sub_workflow_ref_identifier = hydrated_workflow_template.nodes[2].workflow_node.sub_workflow_ref + assert sub_workflow_ref_identifier.project == "project2" + assert sub_workflow_ref_identifier.domain == "domain2" + assert sub_workflow_ref_identifier.version == "12345" + + unchanged_identifier = hydrated_workflow_template.nodes[3].task_node.reference_id + assert unchanged_identifier.project == "project2" + assert unchanged_identifier.domain == "domain2" + assert unchanged_identifier.version == "abc" + + +def test_hydrate_registration_parameters__launch_plan_already_set(): + launch_plan = _launch_plan_pb2.LaunchPlanSpec( + workflow_id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + project="project2", + domain="domain2", + name="workflow_name", + version="abc", + ) + ) + identifier, entity = hydrate_registration_parameters( + _identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project2", + domain="domain2", + name="workflow_name", + version="abc", + ), + "project", + "domain", + "12345", + launch_plan, + ) + assert identifier == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project2", + domain="domain2", + name="workflow_name", + version="abc", + ) + assert entity.workflow_id == launch_plan.workflow_id + + +def test_hydrate_registration_parameters__launch_plan_nothing_set(): + launch_plan = _launch_plan_pb2.LaunchPlanSpec( + workflow_id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="workflow_name",) + ) + identifier, entity = hydrate_registration_parameters( + _identifier_pb2.Identifier(resource_type=_identifier_pb2.LAUNCH_PLAN, name="workflow_name"), + "project", + "domain", + "12345", + launch_plan, + ) + assert identifier == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project", + domain="domain", + name="workflow_name", + version="12345", + ) + assert entity.workflow_id == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + project="project", + domain="domain", + name="workflow_name", + version="12345", + ) + + +def test_hydrate_registration_parameters__task_already_set(): + task = _task_pb2.TaskSpec( + template=_core_task_pb2.TaskTemplate( + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + ), + ) + ) + identifier, entity = hydrate_registration_parameters(task.template.id, "project", "domain", "12345", task) + assert ( + identifier + == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + ) + == entity.template.id + ) + + +def test_hydrate_registration_parameters__task_nothing_set(): + task = _task_pb2.TaskSpec( + template=_core_task_pb2.TaskTemplate( + id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.TASK, name="name",), + ) + ) + identifier, entity = hydrate_registration_parameters(task.template.id, "project", "domain", "12345", task) + assert ( + identifier + == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="name", version="12345", + ) + == entity.template.id + ) + + +def test_hydrate_registration_parameters__workflow_already_set(): + workflow = _workflow_pb2.WorkflowSpec( + template=_core_workflow_pb2.WorkflowTemplate( + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + project="project2", + domain="domain2", + name="name", + version="abc", + ), + ) + ) + identifier, entity = hydrate_registration_parameters(workflow.template.id, "project", "domain", "12345", workflow) + assert ( + identifier + == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", name="name", version="abc", + ) + == entity.template.id + ) + + +def test_hydrate_registration_parameters__workflow_nothing_set(): + workflow = _workflow_pb2.WorkflowSpec( + template=_core_workflow_pb2.WorkflowTemplate( + id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="name",), + ) + ) + identifier, entity = hydrate_registration_parameters(workflow.template.id, "project", "domain", "12345", workflow) + assert ( + identifier + == _identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="name", version="12345", + ) + == entity.template.id + ) diff --git a/tests/flytekit/unit/models/core/test_identifier.py b/tests/flytekit/unit/models/core/test_identifier.py index 75ae18cf23..bf00aed216 100644 --- a/tests/flytekit/unit/models/core/test_identifier.py +++ b/tests/flytekit/unit/models/core/test_identifier.py @@ -8,6 +8,7 @@ def test_identifier(): assert obj.name == "name" assert obj.version == "version" assert obj.resource_type == identifier.ResourceType.TASK + assert obj.resource_type_name() == "TASK" obj2 = identifier.Identifier.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj @@ -16,6 +17,7 @@ def test_identifier(): assert obj2.name == "name" assert obj2.version == "version" assert obj2.resource_type == identifier.ResourceType.TASK + assert obj2.resource_type_name() == "TASK" def test_node_execution_identifier(): @@ -64,3 +66,4 @@ def test_identifier_emptiness(): not_empty_id = identifier.Identifier(identifier.ResourceType.UNSPECIFIED, "", "", "", "version") assert empty_id.is_empty assert not not_empty_id.is_empty + assert not_empty_id.resource_type_name() == "UNSPECIFIED"