diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 3a165084b0..369316c642 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -160,7 +160,6 @@ else:
     from importlib.metadata import entry_points
 
-import flytekit.plugins  # This will be deprecated, these are the old plugins, the new plugins live in plugins/
 from flytekit.core.base_sql_task import SQLTask
 from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
 from flytekit.core.condition import conditional
diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 1888376e50..754eab666d 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -1,10 +1,8 @@
 import contextlib
 import datetime as _datetime
-import importlib as _importlib
 import logging as python_logging
 import os as _os
 import pathlib
-import random as _random
 import traceback as _traceback
 from typing import List
 
@@ -12,18 +10,14 @@
 from flyteidl.core import literals_pb2 as _literals_pb2
 
 from flytekit import PythonFunctionTask
-from flytekit.common import constants as _constants
-from flytekit.common import utils as _common_utils
-from flytekit.common import utils as _utils
-from flytekit.common.exceptions import scopes as _scoped_exceptions
-from flytekit.common.exceptions import scopes as _scopes
-from flytekit.common.exceptions import system as _system_exceptions
-from flytekit.common.tasks.sdk_runnable import ExecutionParameters
 from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
 from flytekit.configuration import internal as _internal_config
 from flytekit.configuration import sdk as _sdk_config
+from flytekit.core import constants as _constants
+from flytekit.core import utils
 from flytekit.core.base_task import IgnoreOutputs, PythonTask
 from flytekit.core.context_manager import (
+    ExecutionParameters,
     ExecutionState,
     FlyteContext,
     FlyteContextManager,
@@ -33,9 +27,8 @@
 from flytekit.core.data_persistence import FileAccessProvider
 from flytekit.core.map_task import MapPythonTask
 from flytekit.core.promise import VoidPromise
-from flytekit.engines import loader as _engine_loader
-from flytekit.interfaces import random as _flyte_random
-from flytekit.interfaces.data import data_proxy as _data_proxy
+from flytekit.exceptions import scopes as _scoped_exceptions
+from flytekit.exceptions import scopes as _scopes
 from flytekit.interfaces.stats.taggable import get_stats as _get_stats
 from flytekit.loggers import entrypoint_logger as logger
 from flytekit.models import dynamic_job as _dynamic_job
@@ -47,6 +40,12 @@
 from flytekit.tools.module_loader import load_object_from_module
 
 
+def get_version_message():
+    import flytekit
+
+    return f"Welcome to Flyte! Version: {flytekit.__version__}"
+
+
 def _compute_array_job_index():
     # type () -> int
     """
@@ -61,25 +60,6 @@ def _compute_array_job_index():
     return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
 
 
-def _map_job_index_to_child_index(local_input_dir, datadir, index):
-    local_lookup_file = local_input_dir.get_named_tempfile("indexlookup.pb")
-    idx_lookup_file = _os.path.join(datadir, "indexlookup.pb")
-
-    # if the indexlookup.pb does not exist, then just return the index
-    if not _data_proxy.Data.data_exists(idx_lookup_file):
-        return index
-
-    _data_proxy.Data.get_data(idx_lookup_file, local_lookup_file)
-    mapping_proto = _utils.load_proto_from_file(_literals_pb2.LiteralCollection, local_lookup_file)
-    if len(mapping_proto.literals) < index:
-        raise _system_exceptions.FlyteSystemAssertion(
-            "dynamic task index lookup array size: {} is smaller than lookup index {}".format(
-                len(mapping_proto.literals), index
-            )
-        )
-    return mapping_proto.literals[index].scalar.primitive.integer
-
-
 def _dispatch_execute(
     ctx: FlyteContext,
     task_def: PythonTask,
@@ -101,7 +81,7 @@ def _dispatch_execute(
     # Step1
     local_inputs_file = _os.path.join(ctx.execution_state.working_dir, "inputs.pb")
     ctx.file_access.get_data(inputs_path, local_inputs_file)
-    input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
+    input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
     idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)
 
     # Step2
@@ -174,7 +154,7 @@ def _dispatch_execute(
         logger.error("!! End Error Captured by Flyte !!")
 
     for k, v in output_file_dict.items():
-        _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k))
+        utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k))
 
     ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True)
     logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")
@@ -283,45 +263,6 @@ def _handle_annotated_task(
         _dispatch_execute(ctx, task_def, inputs, output_prefix)
 
 
-@_scopes.system_entry_point
-def _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test):
-    """
-    This function should be called for old flytekit api tasks (the only API that was available in 0.15.x and earlier)
-    """
-    with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
-        with _utils.AutoDeletingTempDir("input_dir") as input_dir:
-            # Load user code
-            task_module = _importlib.import_module(task_module)
-            task_def = getattr(task_module, task_name)
-
-            local_inputs_file = input_dir.get_named_tempfile("inputs.pb")
-
-            # Handle inputs/outputs for array job.
-            if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
-                job_index = _compute_array_job_index()
-
-                # TODO: Perhaps remove. This is a workaround to an issue we perceived with limited entropy in
-                # TODO: AWS batch array jobs.
-                _flyte_random.seed_flyte_random(
-                    "{} {} {}".format(_random.random(), _datetime.datetime.utcnow(), job_index)
-                )
-
-                # If an ArrayTask is discoverable, the original job index may be different than the one specified in
-                # the environment variable. Look up the correct input/outputs in the index lookup mapping file.
- job_index = _map_job_index_to_child_index(input_dir, inputs, job_index) - - inputs = _os.path.join(inputs, str(job_index), "inputs.pb") - output_prefix = _os.path.join(output_prefix, str(job_index)) - - _data_proxy.Data.get_data(inputs, local_inputs_file) - input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) - - _engine_loader.get_engine().get_task(task_def).execute( - _literal_models.LiteralMap.from_flyte_idl(input_proto), - context={"output_prefix": output_prefix, "raw_output_data_prefix": raw_output_data_prefix}, - ) - - @_scopes.system_entry_point def _execute_task( inputs, @@ -416,8 +357,6 @@ def _pass_through(): @_pass_through.command("pyflyte-execute") -@_click.option("--task-module", required=False) -@_click.option("--task-name", required=False) @_click.option("--inputs", required=True) @_click.option("--output-prefix", required=True) @_click.option("--raw-output-data-prefix", required=False) @@ -431,8 +370,6 @@ def _pass_through(): nargs=-1, ) def execute_task_cmd( - task_module, - task_name, inputs, output_prefix, raw_output_data_prefix, @@ -442,7 +379,7 @@ def execute_task_cmd( resolver, resolver_args, ): - logger.info(_utils.get_version_message()) + logger.info(get_version_message()) # We get weird errors if there are no click echo messages at all, so emit an empty string so that unit tests pass. _click.echo("") # Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original @@ -455,21 +392,17 @@ def execute_task_cmd( # Use the presence of the resolver to differentiate between old API tasks and new API tasks # The addition of a new top-level command seemed out of scope at the time of this writing to pursue given how # pervasive this top level command already (plugins mostly). - if not resolver: - logger.info("No resolver found, assuming legacy API task...") - _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test) - else: - logger.debug(f"Running task execution with resolver {resolver}...") - _execute_task( - inputs, - output_prefix, - raw_output_data_prefix, - test, - resolver, - resolver_args, - dynamic_addl_distro, - dynamic_dest_dir, - ) + logger.debug(f"Running task execution with resolver {resolver}...") + _execute_task( + inputs, + output_prefix, + raw_output_data_prefix, + test, + resolver, + resolver_args, + dynamic_addl_distro, + dynamic_dest_dir, + ) @_pass_through.command("pyflyte-fast-execute") @@ -528,7 +461,7 @@ def map_execute_task_cmd( resolver, resolver_args, ): - logger.info(_utils.get_version_message()) + logger.info(get_version_message()) _execute_map_task( inputs, diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 5a6bbc4fdc..5daa009ad5 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -560,7 +560,7 @@ def create_execution(self, project, domain, name, execution_spec, inputs): def recover_execution(self, id, name: str = None): """ Recreates a previously-run workflow execution that will only start executing from the last known failure point. - :param flytekit.common.core.identifier.WorkflowExecutionIdentifier id: + :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id: :param name str: Optional name to assign to the newly created execution. 
        :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier
         """
@@ -572,7 +572,7 @@ def recover_execution(self, id, name: str = None):
 
     def get_execution(self, id):
         """
-        :param flytekit.common.core.identifier.WorkflowExecutionIdentifier id:
+        :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id:
         :rtype: flytekit.models.execution.Execution
         """
         return _execution.Execution.from_flyte_idl(
@@ -638,7 +638,7 @@ def list_executions_paginated(self, project, domain, limit=100, token=None, filt
 
     def terminate_execution(self, id, cause):
         """
-        :param flytekit.common.core.identifier.WorkflowExecutionIdentifier id:
+        :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id:
         :param Text cause:
         """
         super(SynchronousFlyteClient, self).terminate_execution(
@@ -647,7 +647,7 @@ def terminate_execution(self, id, cause):
 
     def relaunch_execution(self, id, name=None):
         """
-        :param flytekit.common.core.identifier.WorkflowExecutionIdentifier id:
+        :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id:
         :param Text name: [Optional] name for the new execution. If not specified, a randomly generated name will
             be used
         :returns: The unique identifier for the new execution.
diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py
index 4e2dc71e25..2df64f080e 100644
--- a/flytekit/clients/helpers.py
+++ b/flytekit/clients/helpers.py
@@ -9,8 +9,8 @@ def iterate_node_executions(
     """
     This returns a generator for node executions.
     :param flytekit.clients.friendly.SynchronousFlyteClient client:
-    :param flytekit.common.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier:
-    :param flytekit.common.core.identifier.TaskExecutionIdentifier task_execution_identifier:
+    :param flytekit.models.core.identifier.WorkflowExecutionIdentifier workflow_execution_identifier:
+    :param flytekit.models.core.identifier.TaskExecutionIdentifier task_execution_identifier:
     :param int limit: The maximum number of elements to retrieve
     :param list[flytekit.models.filters.Filter] filters:
     :rtype: Iterator[flytekit.models.node_execution.NodeExecution]
diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index b8f7f578cf..ce58d8cc57 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -2,7 +2,6 @@
 import time
 from typing import List
 
-import six as _six
 from flyteidl.service import admin_pb2_grpc as _admin_service
 from google.protobuf.json_format import MessageToJson as _MessageToJson
 from grpc import RpcError as _RpcError
@@ -13,13 +12,13 @@
 
 from flytekit.clis.auth import credentials as _credentials_access
 from flytekit.clis.sdk_in_container import basic_auth as _basic_auth
-from flytekit.common.exceptions import user as _user_exceptions
 from flytekit.configuration import creds as _creds_config
 from flytekit.configuration.creds import _DEPRECATED_CLIENT_CREDENTIALS_SCOPE as _DEPRECATED_SCOPE
 from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID
 from flytekit.configuration.creds import COMMAND as _COMMAND
 from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES, SCOPES
 from flytekit.configuration.platform import AUTH as _AUTH
+from flytekit.exceptions import user as _user_exceptions
 from flytekit.loggers import cli_logger
 
 
@@ -91,9 +90,7 @@ def _refresh_credentials_from_command(flyte_client):
         output = subprocess.run(command, capture_output=True, text=True, check=True)
     except subprocess.CalledProcessError as e:
         cli_logger.error("Failed to generate token from command {}".format(command))
-        raise _user_exceptions.FlyteAuthenticationException(
-            "Problems refreshing token with command: " + _six.text_type(e)
-        )
+        raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e))
 
     flyte_client.set_access_token(output.stdout.strip())
 
@@ -134,7 +131,7 @@ def handler(*args, **kwargs):
                 # Always retry auth errors.
                 if i == (max_retries - 1):
                     # Exit the loop and wrap the authentication error.
-                    raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
+                    raise _user_exceptions.FlyteAuthenticationException(str(e))
                 cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
                 refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
                 refresh_handler_fn(args[0])
diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py
index 438e91141f..82300a9c7f 100644
--- a/flytekit/clis/flyte_cli/main.py
+++ b/flytekit/clis/flyte_cli/main.py
@@ -20,23 +20,14 @@
 
 from flytekit import __version__
 from flytekit.clients import friendly as _friendly_client
-from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map
-from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map
 from flytekit.clis.helpers import hydrate_registration_parameters
-from flytekit.clis.helpers import parse_args_into_dict as _parse_args_into_dict
-from flytekit.common import launch_plan as _launch_plan_common
-from flytekit.common import utils as _utils
-from flytekit.common import workflow_execution as _workflow_execution_common
-from flytekit.common.core import identifier as _identifier
-from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.tasks import task as _tasks_common
-from flytekit.common.types import helpers as _type_helpers
-from flytekit.common.utils import load_proto_from_file as _load_proto_from_file
 from flytekit.configuration import auth as _auth_config
 from flytekit.configuration import platform as _platform_config
 from flytekit.configuration import set_flyte_config_file
-from flytekit.interfaces.data import data_proxy as _data_proxy
-from flytekit.interfaces.data.data_proxy import Data
+from flytekit.core import utils
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.exceptions import user as _user_exceptions
+from flytekit.interfaces import cli_identifiers
 from flytekit.models import common as _common_models
 from flytekit.models import filters as _filters
 from flytekit.models import launch_plan as _launch_plan
@@ -47,8 +38,6 @@
 from flytekit.models.common import RawOutputDataConfig as _RawOutputDataConfig
 from flytekit.models.core import execution as _core_execution_models
 from flytekit.models.core import identifier as _core_identifier
-from flytekit.models.execution import ExecutionMetadata as _ExecutionMetadata
-from flytekit.models.execution import ExecutionSpec as _ExecutionSpec
 from flytekit.models.matchable_resource import ClusterResourceAttributes as _ClusterResourceAttributes
 from flytekit.models.matchable_resource import ExecutionClusterLabel as _ExecutionClusterLabel
 from flytekit.models.matchable_resource import ExecutionQueueAttributes as _ExecutionQueueAttributes
@@ -118,20 +107,7 @@ def _get_io_string(literal_map, verbose=False):
     :param bool verbose:
     :rtype: Text
     """
-    value_dict = _type_helpers.unpack_literal_map_to_sdk_object(literal_map)
-    if value_dict:
-        return "\n" + "\n".join(
-            "{:30}: {}".format(
-                k,
-                _prefix_lines(
-                    "{:30} ".format(""),
-                    v.verbose_string() if verbose else v.short_string(),
-                ),
-            )
-            for k, v in value_dict.items()
-        )
-    else:
-        return "(None)"
+    return str(literal_map)
 
 
 def _fetch_and_stringify_literal_map(path, verbose=False):
@@ -140,12 +116,13 @@ def _fetch_and_stringify_literal_map(path, verbose=False):
     :param bool verbose:
     :rtype: Text
     """
-    with _utils.AutoDeletingTempDir("flytecli") as tmp:
+    ctx = FlyteContextManager.current_context()
+    with utils.AutoDeletingTempDir("flytecli") as tmp:
         try:
             fname = tmp.get_named_tempfile("literalmap.pb")
-            _data_proxy.Data.get_data(path, fname)
+            ctx.file_access.get_data(path, fname)
             literal_map = _literals.LiteralMap.from_flyte_idl(
-                _utils.load_proto_from_file(_literals_pb2.LiteralMap, fname)
+                utils.load_proto_from_file(_literals_pb2.LiteralMap, fname)
             )
             return _get_io_string(literal_map, verbose=verbose)
         except Exception:
@@ -254,7 +231,7 @@ def _secho_one_execution(ex, urns_only):
     if not urns_only:
         _click.echo(
             "{:100} {:40} {:40}".format(
-                _tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id)),
+                _tt(cli_identifiers.WorkflowExecutionIdentifier.promote_from_model(ex.id)),
                 _tt(ex.id.name),
                 _tt(ex.spec.launch_plan.name),
             ),
@@ -263,7 +240,7 @@ def _secho_one_execution(ex, urns_only):
         _secho_workflow_status(ex.closure.phase)
     else:
         _click.echo(
-            "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))),
+            "{:100}".format(_tt(cli_identifiers.WorkflowExecutionIdentifier.promote_from_model(ex.id))),
             nl=True,
         )
 
@@ -271,7 +248,7 @@ def _secho_one_execution(ex, urns_only):
 def _terminate_one_execution(client, urn, cause, shouldPrint=True):
     if shouldPrint:
         _click.echo("{:100} {:40}".format(_tt(urn), _tt(cause)))
-    client.terminate_execution(_identifier.WorkflowExecutionIdentifier.from_python_std(urn), cause)
+    client.terminate_execution(cli_identifiers.WorkflowExecutionIdentifier.from_python_std(urn), cause)
 
 
 def _update_one_launch_plan(client: _friendly_client.SynchronousFlyteClient, urn, state):
@@ -279,7 +256,7 @@ def _update_one_launch_plan(client: _friendly_client.SynchronousFlyteClient, urn
         state = _launch_plan.LaunchPlanState.ACTIVE
     else:
         state = _launch_plan.LaunchPlanState.INACTIVE
-    client.update_launch_plan(_identifier.Identifier.from_python_std(urn), state)
+    client.update_launch_plan(cli_identifiers.Identifier.from_python_std(urn), state)
     _click.echo("Successfully updated {}".format(_tt(urn)))
 
 
@@ -644,7 +621,7 @@ def parse_proto(filename, proto_class):
     idl_obj = split[-1]
     mod = _importlib.import_module(idl_module)
     idl = getattr(mod, idl_obj)
-    obj = _load_proto_from_file(idl, filename)
+    obj = utils.load_proto_from_file(idl, filename)
 
     jsonObj = MessageToJson(obj)
 
@@ -734,7 +711,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show
         _click.echo(
             "{:50} {:40}".format(
                 _tt(t.id.version),
-                _tt(_identifier.Identifier.promote_from_model(t.id)),
+                _tt(cli_identifiers.Identifier.promote_from_model(t.id)),
             )
         )
 
@@ -760,57 +737,11 @@ def get_task(urn, host, insecure):
     _welcome_message()
     parent_ctx = _click.get_current_context(silent=True)
     client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"])
-    t = client.get_task(_identifier.Identifier.from_python_std(urn))
+    t = client.get_task(cli_identifiers.Identifier.from_python_std(urn))
     _click.echo(_tt(t))
     _click.echo("")
 
 
-@_flyte_cli.command("launch-task", cls=_FlyteSubCommand)
-@_project_option
-@_domain_option
-@_optional_name_option
-@_assumable_iam_role_option
-@_kubernetes_service_acct_option
-@_host_option
-@_insecure_option
-@_urn_option
-@_click.argument("task_args", nargs=-1, type=_click.UNPROCESSED)
-def launch_task(project, domain, name, assumable_iam_role, kubernetes_service_account, host, insecure, urn, task_args):
-    """
-    Kick off a single task execution. Note that the {project, domain, name} specified in the command line
-    will be for the execution. The project/domain for the task are specified in the urn.
-
-    Use a -- to separate arguments to this cli, and arguments to the task.
-    e.g.
-        $ flyte-cli -h localhost:30081 -p flyteexamples -d development launch-task \
-            -u tsk:flyteexamples:development:some-task:abc123 -- input=hi \
-            other-input=123 moreinput=qwerty
-
-    These arguments are then collected, and passed into the `task_args` variable as a Tuple[Text].
-    Users should use the get-task command to ascertain the names of inputs to use.
-    """
-    _welcome_message()
-    auth_role = _AuthRole(assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account)
-
-    with _platform_config.URL.get_patcher(host), _platform_config.INSECURE.get_patcher(_tt(insecure)):
-        task_id = _identifier.Identifier.from_python_std(urn)
-        task = _tasks_common.SdkTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
-
-        text_args = _parse_args_into_dict(task_args)
-        inputs = {}
-        for var_name, variable in task.interface.inputs.items():
-            sdk_type = _type_helpers.get_sdk_type_from_literal_type(variable.type)
-            if var_name in text_args and text_args[var_name] is not None:
-                inputs[var_name] = sdk_type.from_string(text_args[var_name]).to_python_std()
-
-        # TODO: Implement notification overrides
-        # TODO: Implement label overrides
-        # TODO: Implement annotation overrides
-        execution = task.launch(project, domain, inputs=inputs, name=name, auth_role=auth_role)
-        _click.secho("Launched execution: {}".format(_tt(execution.id)), fg="blue")
-        _click.echo("")
-
-
 ########################################################################################################################
 #
 # Workflow Commands
 #
 ########################################################################################################################
@@ -892,7 +823,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit,
         _click.echo(
             "{:50} {:40}".format(
                 _tt(w.id.version),
-                _tt(_identifier.Identifier.promote_from_model(w.id)),
+                _tt(cli_identifiers.Identifier.promote_from_model(w.id)),
             )
         )
 
@@ -918,7 +849,7 @@ def get_workflow(urn, host, insecure):
     _welcome_message()
     parent_ctx = _click.get_current_context(silent=True)
     client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"])
-    _click.echo(client.get_workflow(_identifier.Identifier.from_python_std(urn)))
+    _click.echo(client.get_workflow(cli_identifiers.Identifier.from_python_std(urn)))
     # TODO: Print workflow pretty
     _click.echo("")
 
@@ -1003,13 +934,13 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show
 
     for lp in active_lps:
         if urns_only:
-            _click.echo("{:80}".format(_tt(_identifier.Identifier.promote_from_model(lp.id))))
+            _click.echo("{:80}".format(_tt(cli_identifiers.Identifier.promote_from_model(lp.id))))
         else:
             _click.echo(
                 "{:30} {:50} {:80}".format(
                     _render_schedule_expr(lp),
                     _tt(lp.id.version),
-                    _tt(_identifier.Identifier.promote_from_model(lp.id)),
+                    _tt(cli_identifiers.Identifier.promote_from_model(lp.id)),
                 ),
             )
 
@@ -1072,12 +1003,12 @@ def list_launch_plan_versions(
     )
     for l in lp_list:
         if urns_only:
-            _click.echo(_tt(_identifier.Identifier.promote_from_model(l.id)))
+            _click.echo(_tt(cli_identifiers.Identifier.promote_from_model(l.id)))
         else:
             _click.echo(
                 "{:50} {:80} ".format(
                     _tt(l.id.version),
-                    _tt(_identifier.Identifier.promote_from_model(l.id)),
+                    _tt(cli_identifiers.Identifier.promote_from_model(l.id)),
                 ),
                 nl=False,
             )
@@ -1115,7 +1046,7 @@ def get_launch_plan(urn, host, insecure):
     _welcome_message()
     parent_ctx = _click.get_current_context(silent=True)
     client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"])
-    _click.echo(_tt(client.get_launch_plan(_identifier.Identifier.from_python_std(urn))))
+    _click.echo(_tt(client.get_launch_plan(cli_identifiers.Identifier.from_python_std(urn))))
     # TODO: Print launch plan pretty
     _click.echo("")
 
@@ -1167,50 +1098,6 @@ def update_launch_plan(state, host, insecure, urn=None):
             _update_one_launch_plan(client, urn=urn, state=state)
 
 
-@_flyte_cli.command("execute-launch-plan", cls=_FlyteSubCommand)
-@_project_option
-@_domain_option
-@_optional_name_option
-@_host_option
-@_insecure_option
-@_urn_option
-@_principal_option
-@_verbose_option
-@_watch_option
-@_click.argument("lp_args", nargs=-1, type=_click.UNPROCESSED)
-def execute_launch_plan(project, domain, name, host, insecure, urn, principal, verbose, watch, lp_args):
-    """
-    Kick off a launch plan. Note that the {project, domain, name} specified in the command line
-    will be for the execution. The project/domain for the launch plan are specified in the urn.
-
-    Use a -- to separate arguments to this cli, and arguments to the launch plan.
-    e.g.
-        $ flyte-cli -h localhost:30081 -p flyteexamples -d development execute-launch-plan \
-            --verbose --principal=sdk-demo
-            -u lp:flyteexamples:development:some-workflow:abc123 -- input=hi \
-            other-input=123 moreinput=qwerty
-
-    These arguments are then collected, and passed into the `lp_args` variable as a Tuple[Text].
-    Users should use the get-launch-plan command to ascertain the names of inputs to use.
-    """
-    _welcome_message()
-
-    with _platform_config.URL.get_patcher(host), _platform_config.INSECURE.get_patcher(_tt(insecure)):
-        lp_id = _identifier.Identifier.from_python_std(urn)
-        lp = _launch_plan_common.SdkLaunchPlan.fetch(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
-
-        inputs = _construct_literal_map_from_parameter_map(lp.default_inputs, _parse_args_into_dict(lp_args))
-        # TODO: Implement notification overrides
-        # TODO: Implement label overrides
-        # TODO: Implement annotation overrides
-        execution = lp.launch_with_literals(project, domain, inputs, name=name)
-        _click.secho("Launched execution: {}".format(_tt(execution.id)), fg="blue")
-        _click.echo("")
-
-        if watch is True:
-            execution.wait_for_completion()
-
-
 ########################################################################################################################
 #
 # Execution Commands
 #
 ########################################################################################################################
@@ -1218,113 +1105,6 @@
-@_flyte_cli.command("watch-execution", cls=_FlyteSubCommand)
-@_host_option
-@_insecure_option
-@_urn_option
-def watch_execution(host, insecure, urn):
-    """
-    Wait for an execution to complete.
-
-    e.g.
- $ flyte-cli -h localhost:30081 watch-execution -u ex:flyteexamples:development:abc123 - """ - _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) - ex_id = _identifier.WorkflowExecutionIdentifier.from_python_std(urn) - - execution = _workflow_execution_common.SdkWorkflowExecution.promote_from_model(client.get_execution(ex_id)) - - _click.echo("Waiting for the execution {} to complete ...".format(_tt(execution.id))) - - with _platform_config.URL.get_patcher(host), _platform_config.INSECURE.get_patcher(_tt(insecure)): - execution.wait_for_completion() - - -@_flyte_cli.command("relaunch-execution", cls=_FlyteSubCommand) -@_optional_project_option -@_optional_domain_option -@_optional_name_option -@_host_option -@_insecure_option -@_urn_option -@_optional_principal_option -@_verbose_option -@_click.argument("lp_args", nargs=-1, type=_click.UNPROCESSED) -def relaunch_execution(project, domain, name, host, insecure, urn, principal, verbose, lp_args): - """ - Relaunch a launch plan. - As with kicking off a launch plan (see execute-launch-plan), the project and domain will correspond to the new - execution to be run, and the project/domain used to find the existing execution will come from the URN. - This means you can re-run a development execution, in production, off of a staging launch-plan (in another project), - but beware that execution environment configurations can result in slower executions or permissions failures. - Therefore, it is recommended to re-run in the same environment as the original execution. By default, if the - project and domain are not specified, the existing project/domain will be used. - - When relaunching an execution, this will display the fixed inputs that it ran with (from the launch plan spec), - and handle the other inputs similar to how we handle initial launch plan execution, except that - all inputs now will have a default (the input of the execution being rerun). - - Use a -- to separate arguments to this cli, and arguments to the launch plan. - e.g. - $ flyte-cli -h localhost:30081 -p flyteexamples -d development execute-launch-plan \ - -u lp:flyteexamples:development:some-workflow:abc123 -- input=hi \ - other-input=123 moreinput=qwerty - - These arguments are then collected, and passed into the `lp_args` variable as a Tuple[Text]. - Users should use the get-execution and get-launch-plan commands to ascertain the names of inputs to use. - """ - _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) - - _click.echo("Relaunching execution {}\n".format(_tt(urn))) - existing_workflow_execution_identifier = _identifier.WorkflowExecutionIdentifier.from_python_std(urn) - e = client.get_execution(existing_workflow_execution_identifier) - - if project is None: - project = existing_workflow_execution_identifier.project - if domain is None: - domain = existing_workflow_execution_identifier.domain - if principal is None: - principal = e.spec.metadata.principal - - lp_model = client.get_launch_plan(e.spec.launch_plan) - expected_inputs = lp_model.closure.expected_inputs - - # Parse text inputs using the LP closure's parameter map to determine types. However, since all inputs are now - # optional (because we can default to the original execution's), we reduce first to bare Variables. 
-    variable_map = {k: v.var for k, v in expected_inputs.parameters.items()}
-    parsed_text_args = _parse_args_into_dict(lp_args)
-    new_inputs = _construct_literal_map_from_variable_map(variable_map, parsed_text_args)
-    if len(new_inputs.literals) > 0:
-        _click.secho("\tNew Inputs: {}\n".format(_prefix_lines("\t\t", _get_io_string(new_inputs, verbose=verbose))))
-
-    # Construct new inputs from existing execution inputs and new inputs
-    inputs_dict = {}
-    for k in e.spec.inputs.literals.keys():
-        if k in new_inputs.literals:
-            inputs_dict[k] = new_inputs.literals[k]
-        else:
-            inputs_dict[k] = e.spec.inputs.literals[k]
-    inputs = _literals.LiteralMap(literals=inputs_dict)
-
-    if len(inputs_dict) > 0:
-        _click.secho(
-            "\tFinal Inputs for New Execution: {}\n".format(
-                _prefix_lines("\t\t", _get_io_string(inputs, verbose=verbose))
-            )
-        )
-
-    metadata = _ExecutionMetadata(mode=_ExecutionMetadata.ExecutionMode.MANUAL, principal=principal, nesting=0)
-    ex_spec = _ExecutionSpec(launch_plan=lp_model.id, inputs=inputs, metadata=metadata)
-    execution_identifier = client.create_execution(project=project, domain=domain, name=name, execution_spec=ex_spec)
-    execution_identifier = _identifier.WorkflowExecutionIdentifier.promote_from_model(execution_identifier)
-    _click.secho("Launched execution: {}".format(execution_identifier), fg="blue")
-    _click.echo("")
-
-
 @_flyte_cli.command("recover-execution", cls=_FlyteSubCommand)
 @_urn_option
 @_optional_name_option
@@ -1356,10 +1136,10 @@ def recover_execution(urn, name, host, insecure):
 
     _click.echo("Recovering execution {}\n".format(_tt(urn)))
 
-    original_workflow_execution_identifier = _identifier.WorkflowExecutionIdentifier.from_python_std(urn)
+    original_workflow_execution_identifier = cli_identifiers.WorkflowExecutionIdentifier.from_python_std(urn)
 
     execution_identifier_resp = client.recover_execution(id=original_workflow_execution_identifier, name=name)
-    execution_identifier = _identifier.WorkflowExecutionIdentifier.promote_from_model(execution_identifier_resp)
+    execution_identifier = cli_identifiers.WorkflowExecutionIdentifier.promote_from_model(execution_identifier_resp)
     _click.secho("Launched execution: {}".format(execution_identifier), fg="blue")
     _click.echo("")
 
@@ -1500,7 +1280,7 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos
     _click.echo(
         "\t{:15} {}".format(
             "Launch Plan:",
-            _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)),
+            _tt(cli_identifiers.Identifier.promote_from_model(wf_execution.spec.launch_plan)),
         )
     )
 
@@ -1675,7 +1455,7 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure
                         "Subtasks:",
                         "flyte-cli get-child-executions -h {host}{insecure} -u {urn}".format(
                             host=host,
-                            urn=_tt(_identifier.TaskExecutionIdentifier.promote_from_model(te.id)),
+                            urn=_tt(cli_identifiers.TaskExecutionIdentifier.promote_from_model(te.id)),
                             insecure=" --insecure" if insecure else "",
                         ),
                     )
@@ -1699,7 +1479,7 @@ def get_execution(urn, host, insecure, show_io, verbose):
     _welcome_message()
     parent_ctx = _click.get_current_context(silent=True)
     client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"])
-    e = client.get_execution(_identifier.WorkflowExecutionIdentifier.from_python_std(urn))
+    e = client.get_execution(cli_identifiers.WorkflowExecutionIdentifier.from_python_std(urn))
     node_execs = _get_all_node_executions(client, workflow_execution_identifier=e.id)
     _render_node_executions(client, node_execs, show_io, verbose, host, insecure, wf_execution=e)
@@ -1716,7 +1496,7 @@ def get_child_executions(urn, host, insecure, show_io, verbose):
     client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"])
     node_execs = _get_all_node_executions(
         client,
-        task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn),
+        task_execution_identifier=cli_identifiers.TaskExecutionIdentifier.from_python_std(urn),
     )
     _render_node_executions(client, node_execs, show_io, verbose, host, insecure)
 
@@ -1838,7 +1618,7 @@ def _extract_pair(
             f"Resource type found in proto file name [{resource_type}] invalid, "
             "must be 1 (task), 2 (workflow) or 3 (launch plan)"
         )
-    entity = _load_proto_from_file(_resource_map[resource_type], object_file)
+    entity = utils.load_proto_from_file(_resource_map[resource_type], object_file)
     registerable_identifier, registerable_entity = hydrate_registration_parameters(
         resource_type, project, domain, version, entity
     )
@@ -2079,7 +1859,8 @@ def fast_register_files(
     version = version if version else digest
     full_remote_path = _get_additional_distribution_loc(additional_distribution_dir, version)
-    Data.put_data(compressed_source, full_remote_path)
+    ctx = FlyteContextManager.current_context()
+    ctx.file_access.put_data(compressed_source, full_remote_path)
     _click.secho(f"Uploaded compressed code archive {compressed_source} to {full_remote_path}", fg="green")
 
     def fast_register_task(entity: _GeneratedProtocolMessageType) -> _GeneratedProtocolMessageType:
diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py
index c0ab7397c2..f87a7d2a11 100644
--- a/flytekit/clis/helpers.py
+++ b/flytekit/clis/helpers.py
@@ -7,34 +7,6 @@
 from flyteidl.core import workflow_pb2 as _workflow_pb2
 
 from flytekit.clis.sdk_in_container.serialize import _DOMAIN_PLACEHOLDER, _PROJECT_PLACEHOLDER, _VERSION_PLACEHOLDER
-from flytekit.common.types.helpers import get_sdk_type_from_literal_type as _get_sdk_type_from_literal_type
-from flytekit.models import literals as _literals
-
-
-def construct_literal_map_from_variable_map(variable_dict, text_args):
-    """
-    This function produces a map of Literals to use when creating an execution. It reads the required values from
-    a Variable map (presumably obtained from a launch plan), and then fills in the necessary inputs
-    from the click args. Click args will be strings, which will be parsed into their SDK types with each
-    SDK type's parse string method.
-
-    :param dict[Text, flytekit.models.interface.Variable] variable_dict:
-    :param dict[Text, Text] text_args:
-    :rtype: flytekit.models.literals.LiteralMap
-    """
-    inputs = {}
-
-    for var_name, variable in variable_dict.items():
-        # Check to see if it's passed from click
-        # If this is an input that has a default from the LP, it should've already been parsed into a string,
-        # and inserted into the default for this option, so it should still be here.
-        if var_name in text_args and text_args[var_name] is not None:
-            # the SDK type is also available from the sdk workflow object's saved user inputs but let's
-            # derive it here from the interface to be more rigorous.
-            sdk_type = _get_sdk_type_from_literal_type(variable.type)
-            inputs[var_name] = sdk_type.from_string(text_args[var_name])
-
-    return _literals.LiteralMap(literals=inputs)
 
 
 def parse_args_into_dict(input_arguments):
@@ -49,36 +21,6 @@ def parse_args_into_dict(input_arguments):
     return {split_arg[0]: split_arg[1] for split_arg in [input_arg.split("=", 1) for input_arg in input_arguments]}
 
 
-def construct_literal_map_from_parameter_map(parameter_map, text_args):
-    """
-    Take a dictionary of Text to Text and construct a literal map using a ParameterMap as guidance.
-    Required input parameters must have an entry in the text arguments given.
-    Parameters with defaults will have those defaults filled in if missing from the text arguments.
-
-    :param flytekit.models.interface.ParameterMap parameter_map:
-    :param dict[Text, Text] text_args:
-    :rtype: flytekit.models.literals.LiteralMap
-    """
-
-    # This function can be written by calling construct_literal_map_from_variable_map also, but not that much
-    # code is saved.
-    inputs = {}
-    for var_name, parameter in parameter_map.parameters.items():
-        sdk_type = _get_sdk_type_from_literal_type(parameter.var.type)
-        if parameter.required:
-            if var_name in text_args and text_args[var_name] is not None:
-                inputs[var_name] = sdk_type.from_string(text_args[var_name])
-            else:
-                raise Exception("Missing required parameter {}".format(var_name))
-        else:
-            if var_name in text_args and text_args[var_name] is not None:
-                inputs[var_name] = sdk_type.from_string(text_args[var_name])
-            else:
-                inputs[var_name] = parameter.default
-
-    return _literals.LiteralMap(literals=inputs)
-
-
 def str2bool(str):
     """
     bool('False') is True in Python, so we need to do some string parsing. Use the same words in ConfigParser
diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py
index 8806c5c406..92612595ad 100644
--- a/flytekit/clis/sdk_in_container/basic_auth.py
+++ b/flytekit/clis/sdk_in_container/basic_auth.py
@@ -3,8 +3,8 @@
 
 import requests as _requests
 
-from flytekit.common.exceptions.user import FlyteAuthenticationException as _FlyteAuthenticationException
 from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET
+from flytekit.exceptions.user import FlyteAuthenticationException
 
 _utf_8 = "utf-8"
 
@@ -18,7 +18,7 @@ def get_secret():
     secret = _CREDENTIALS_SECRET.get()
     if secret:
         return secret
-    raise _FlyteAuthenticationException("No secret could be found")
+    raise FlyteAuthenticationException("No secret could be found")
 
 
 def get_basic_authorization_header(client_id, client_secret):
@@ -55,7 +55,7 @@ def get_token(token_endpoint, authorization_header, scope):
     response = _requests.post(token_endpoint, data=body, headers=headers)
     if response.status_code != 200:
         _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
-        raise _FlyteAuthenticationException("Non-200 received from IDP")
+        raise FlyteAuthenticationException("Non-200 received from IDP")
 
     response = response.json()
     return response["access_token"], response["expires_in"]
diff --git a/flytekit/clis/sdk_in_container/fast_register.py b/flytekit/clis/sdk_in_container/fast_register.py
deleted file mode 100644
index 944ef502bf..0000000000
--- a/flytekit/clis/sdk_in_container/fast_register.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import os as _os
-from typing import List as _List
-
-import click
-
-from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_TEST
-from flytekit.common import utils as _utils
-from flytekit.common.core import identifier as _identifier
-from flytekit.common.tasks import sdk_runnable as _sdk_runnable_task
-from flytekit.common.tasks import task as _task
-from flytekit.configuration import sdk as _sdk_config
-from flytekit.tools.fast_registration import compute_digest as _compute_digest
-from flytekit.tools.fast_registration import get_additional_distribution_loc as _get_additional_distribution_loc
-from flytekit.tools.fast_registration import upload_package as _upload_package
-from flytekit.tools.module_loader import iterate_registerable_entities_in_order
-
-
-def fast_register_all(
-    project: str,
-    domain: str,
-    pkgs: _List[str],
-    test: bool,
-    version: str,
-    source_dir: _os.PathLike,
-    dest_dir: _os.PathLike = None,
-):
-    if test:
-        click.echo("Test switch enabled, not doing anything...")
-
-    if not version:
-        digest = _compute_digest(source_dir)
-    else:
-        digest = version
-    remote_package_path = _upload_package(source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())
-
-    click.echo(
-        "Running task, workflow, and launch plan fast registration for {}, {}, {} with version {} and code dir {}".format(
-            project, domain, pkgs, digest, source_dir
-        )
-    )
-
-    # m = module (i.e. python file)
-    # k = value of dir(m), type str
-    # o = object (e.g. SdkWorkflow)
-    for m, k, o in iterate_registerable_entities_in_order(pkgs):
-        name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
-        o._id = _identifier.Identifier(o.resource_type, project, domain, name, digest)
-
-        if test:
-            click.echo("Would fast register {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
-        else:
-            click.echo("Fast registering {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
-            _get_additional_distribution_loc(_sdk_config.FAST_REGISTRATION_DIR.get(), digest)
-            if isinstance(o, _sdk_runnable_task.SdkRunnableTask):
-                o.fast_register(project, domain, o.id.name, digest, remote_package_path, dest_dir)
-            else:
-                o.register(project, domain, o.id.name, digest)
-
-
-def fast_register_tasks_only(
-    project: str,
-    domain: str,
-    pkgs: _List[str],
-    test: bool,
-    version: str,
-    source_dir: _os.PathLike,
-    dest_dir: _os.PathLike = None,
-):
-    if test:
-        click.echo("Test switch enabled, not doing anything...")
-
-    if not version:
-        digest = _compute_digest(source_dir)
-    else:
-        digest = version
-    remote_package_path = _upload_package(source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())
-
-    click.echo(
-        "Running task only fast registration for {}, {}, {} with version {} and code dir {}".format(
-            project, domain, pkgs, digest, source_dir
-        )
-    )
-
-    # Discover all tasks by loading the module
-    for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask}):
-        name = _utils.fqdn(m.__name__, k, entity_type=t.resource_type)
-
-        if test:
-            click.echo("Would fast register task {:20} {}".format("{}:".format(t.entity_type_text), name))
-        else:
-            click.echo("Fast registering task {:20} {}".format("{}:".format(t.entity_type_text), name))
-            if isinstance(t, _sdk_runnable_task.SdkRunnableTask):
-                t.fast_register(project, domain, name, digest, remote_package_path, dest_dir)
-            else:
-                t.register(project, domain, name, digest)
-
-
-@click.group("fast-register")
-@click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin")
-@click.pass_context
-def fast_register(ctx, test=None):
-    """
-    Run fast registration steps for the Flyte entities in this container. This is an optimization to avoid the
-    conventional container build and upload cycle. This can be useful for fast iteration when making code changes.
-    If you do need to change the container itself (e.g. by adding a new dependency/import) you must rebuild and
-    upload a container.
-
-    Caveats: Your flyte config must specify a fast registration dir like so:
-    [sdk]
-    fast_registration_dir=s3://my-s3-bucket/dir
-
-    **and** ensure that the role specified in [auth] section of your config has read access to this remote location.
-    Furthermore, the role you assume to call fast-register must have **write** permission to this remote location.
-
-    Run with the --test switch for a dry run to see what will be registered. A default launch plan will also be
-    created, if a role can be found in the environment variables.
-    """
-
-    ctx.obj[CTX_TEST] = test
-
-
-@click.command("tasks")
-@click.option(
-    "--source-dir",
-    type=str,
-    help="The root dir of the code that should be uploaded for fast registration.",
-    required=True,
-)
-@click.option(
-    "--dest-dir",
-    type=str,
-    help="[Optional] The output directory of code which is downloaded during fast registration. "
-    "If the current working directory at the time of installation is not desired",
-)
-@click.option(
-    "-v",
-    "--version",
-    type=str,
-    help="Version to register tasks with. This is normally computed deterministically from your code, "
-    "but you can override here.",
-)
-@click.pass_context
-def tasks(ctx, source_dir, dest_dir=None, version=None):
-    """
-    Only fast register tasks.
-
-    For example, consider a sample directory where tasks defined in workflows/ imports code from util/ like so:
-
-    \b
-    $ tree /root/code/
-    /root/code/
-    ├── Dockerfile
-    ├── Makefile
-    ├── README.md
-    ├── conf.py
-    ├── notebook.config
-    ├── workflows
-    │   ├── __init__.py
-    │   ├── compose
-    │   │   ├── README.md
-    │   │   ├── __init__.py
-    │   │   ├── a_workflow.py
-    │   │   ├── b_workflow.py
-    ├── util
-    │   ├── __init__.py
-    │   ├── shared_task_code.py
-    ├── requirements.txt
-    ├── flyte.config
-
-    Your source dir will need to be /root/code/ rather than the workflow packages dir /root/code/workflows you might
-    have specified in your flyte.config because all of the code your workflows depends on needs to be encapsulated in
-    `source_dir`, like so:
-
-    pyflyte -p myproject -d development fast-register tasks --source-dir /root/code/
-
-    """
-    project = ctx.obj[CTX_PROJECT]
-    domain = ctx.obj[CTX_DOMAIN]
-    test = ctx.obj[CTX_TEST]
-    pkgs = ctx.obj[CTX_PACKAGES]
-
-    fast_register_tasks_only(project, domain, pkgs, test, version, source_dir, dest_dir)
-
-
-@click.command("workflows")
-@click.option(
-    "--source-dir",
-    type=str,
-    help="The root dir of the code that should be uploaded for fast registration.",
-    required=True,
-)
-@click.option(
-    "--dest-dir",
-    type=str,
-    help="[Optional] The output directory of code which is downloaded during fast registration. "
-    "If the current working directory at the time of installation is not desired",
-)
-@click.option(
-    "-v",
-    "--version",
-    type=str,
-    help="Version to register entities with. This is normally computed deterministically from your code, "
-    "but you can override here.",
-)
-@click.pass_context
-def workflows(ctx, source_dir, dest_dir=None, version=None):
-    """
-    Fast register both tasks and workflows. Also create and register a default launch plan for all workflows.
-    The `source_dir` param should point to the root directory of your project that contains all of your working code.
- - For example, consider a sample directory structure where code in workflows/ imports code from util/ like so: - - \b - $ tree /root/code/ - /root/code/ - ├── Dockerfile - ├── Makefile - ├── README.md - ├── conf.py - ├── notebook.config - ├── workflows - │   ├── __init__.py - │   ├── compose - │   │   ├── README.md - │   │   ├── __init__.py - │   │   ├── a_workflow.py - │   │   ├── b_workflow.py - ├── util - │   ├── __init__.py - │   ├── shared_workflow_code.py - ├── requirements.txt - ├── flyte.config - - Your source dir will need to be /root/code/ rather than the workflow packages dir /root/code/workflows you might - have specified in your flyte.config because all of the code your workflows depends on needs to be encapsulated in - `source_dir`, like so: - - pyflyte -p myproject -d development fast-register workflows --source-dir /root/code/ - """ - project = ctx.obj[CTX_PROJECT] - domain = ctx.obj[CTX_DOMAIN] - test = ctx.obj[CTX_TEST] - pkgs = ctx.obj[CTX_PACKAGES] - - fast_register_all(project, domain, pkgs, test, version, source_dir, dest_dir) - - -fast_register.add_command(tasks) -fast_register.add_command(workflows) diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py deleted file mode 100644 index e367fa08d8..0000000000 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ /dev/null @@ -1,276 +0,0 @@ -import logging as _logging -import os as _os - -import click -import six as _six - -from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map -from flytekit.clis.sdk_in_container import constants as _constants -from flytekit.clis.sdk_in_container.constants import ( - CTX_DOMAIN, - CTX_PROJECT, - CTX_VERSION, - domain_option, - project_option, - version_option, -) -from flytekit.common import utils as _utils -from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan -from flytekit.configuration.internal import DOMAIN as _DOMAIN -from flytekit.configuration.internal import IMAGE as _IMAGE -from flytekit.configuration.internal import PROJECT as _PROJECT -from flytekit.configuration.internal import VERSION as _VERSION -from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models.core import identifier as _identifier -from flytekit.tools.module_loader import iterate_registerable_entities_in_order - - -class LaunchPlanAbstractGroup(click.Group): - """ - This class iterates over the workflow folders and loads all workflows that are implemented via the programming - model. - """ - - def __init__(self, name, **attrs): - super(LaunchPlanAbstractGroup, self).__init__(name, commands=None, **attrs) - - def list_commands(self, ctx): - commands = [] - lps = {} - pkgs = ctx.obj[_constants.CTX_PACKAGES] - # Discover all launch plans by loading the modules - for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False - ): - safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) - commands.append(safe_name) - lps[safe_name] = lp - - ctx.obj["lps"] = lps - commands.sort() - - return commands - - def get_command(self, ctx, lp_argument): - # Get the launch plan object in one of two ways. If get_command is being called by the list function - # then it should have been cached in the context. 
- # If we are actually running the command, then it won't have been cached and we'll have to load everything again - launch_plan = None - pkgs = ctx.obj[_constants.CTX_PACKAGES] - - if "lps" in ctx.obj: - launch_plan = ctx.obj["lps"][lp_argument] - else: - for m, k, lp in iterate_registerable_entities_in_order( - pkgs, - include_entities={_SdkLaunchPlan}, - detect_unreferenced_entities=False, - ): - safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) - if lp_argument == safe_name: - launch_plan = lp - - if launch_plan is None: - raise Exception("Could not load launch plan {}".format(lp_argument)) - - launch_plan._id = _identifier.Identifier( - _identifier.ResourceType.LAUNCH_PLAN, - ctx.obj[_constants.CTX_PROJECT], - ctx.obj[_constants.CTX_DOMAIN], - lp_argument, - ctx.obj[_constants.CTX_VERSION], - ) - return self._get_command(ctx, launch_plan, lp_argument) - - def _get_command(self, ctx, lp, cmd_name): - """ - :param ctx: - :param flytekit.common.launch_plan.SdkLaunchPlan lp: - :rtype: click.Command - """ - pass - - -class LaunchPlanExecuteGroup(LaunchPlanAbstractGroup): - def _get_command(self, ctx, lp, cmd_name): - """ - This function returns the function that click will actually use to execute a specific launch plan. It also - stores the launch plan python object and the command name in the closure. - - :param ctx: - :param flytekit.common.launch_plan.SdkLaunchPlan lp: - :param Text cmd_name: The name of the launch plan, as passed in from the abstract class - """ - - def _execute_lp(**kwargs): - for input_name in _six.iterkeys(kwargs): - if isinstance(kwargs[input_name], tuple): - kwargs[input_name] = list(kwargs[input_name]) - - inputs = _construct_literal_map_from_parameter_map(lp.default_inputs, kwargs) - execution = lp.execute_with_literals( - ctx.obj[_constants.CTX_PROJECT], - ctx.obj[_constants.CTX_DOMAIN], - literal_inputs=inputs, - notification_overrides=ctx.obj.get(_constants.CTX_NOTIFICATIONS, None), - ) - click.echo( - click.style( - "Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), - fg="blue", - ) - ) - - command = click.Command(name=cmd_name, callback=_execute_lp) - - # Iterate through the workflow's inputs - for var_name in sorted(lp.default_inputs.parameters): - param = lp.default_inputs.parameters[var_name] - # TODO: Figure out how to better handle the fact that we want strings to parse, - # but we probably shouldn't have click say that that's the type on the CLI. - help_msg = "{} Type: {}".format( - _six.text_type(param.var.description), _six.text_type(param.var.type) - ).strip() - - if param.required: - # If it's a required input, add the required flag - wrapper = click.option( - "--{}".format(var_name), - required=True, - type=_six.text_type, - help=help_msg, - ) - else: - # If it's not a required input, it should have a default - # Use to_python_std so that the text of the default ends up being parseable, if not, the click - # arg would look something like 'Integer(10)'. If the user specified '11' on the cli, then - # we'd get '11' and then we'd need annoying logic to differentiate between the default text - # and user text. - default = param.default.to_python_std() - wrapper = click.option( - "--{}".format(var_name), - default="{}".format(_six.text_type(default)), - type=_six.text_type, - help="{}. 
Default: {}".format(help_msg, _six.text_type(default)), - ) - - command = wrapper(command) - - return command - - -@click.group("lp") -@project_option -@domain_option -@version_option -@click.pass_context -def launch_plans(ctx, project, domain, version): - """ - Launch plan control group, including executions - """ - - version = version or _look_up_version_from_image_tag(_IMAGE.get()) - if not version: - raise click.UsageError("Could not find image from config, please specify a value for ``--version``") - - ctx.obj[CTX_PROJECT] = project - ctx.obj[CTX_DOMAIN] = domain - ctx.obj[CTX_VERSION] = version - _os.environ[_PROJECT.env_var] = project - _os.environ[_DOMAIN.env_var] = domain - _os.environ[_VERSION.env_var] = version - - -@click.group("execute", cls=LaunchPlanExecuteGroup) -@click.pass_context -def execute_launch_plan(ctx): - """ - Execute launch plans found in this container - """ - pass - - -def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False): - # TODO: This should be a transaction to ensure all or none are updated - # TODO: We should optionally allow deactivation of missing launch plans - - # Discover all launch plans by loading the modules - _logging.info(f"Setting this version's {version} launch plans active in {project} {domain}") - for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False - ): - lp._id = _identifier.Identifier( - _identifier.ResourceType.LAUNCH_PLAN, - project, - domain, - _utils.fqdn(m.__name__, k, entity_type=lp.resource_type), - version, - ) - if not (lp.is_scheduled and ignore_schedules): - _logging.info(f"Setting active {_utils.fqdn(m.__name__, k, entity_type=lp.resource_type)}") - lp.update(_launch_plan_model.LaunchPlanState.ACTIVE) - - -@click.command("activate-all-schedules") -@click.option( - "-v", - "--version", - type=str, - help="Version to register tasks with. This is normally parsed from the" "image, but you can override here.", -) -@click.pass_context -def activate_all_schedules(ctx, version=None): - """ - THIS COMMAND IS DEPRECATED. PLEASE USE activate-all - - The behavior of this command is identical to activate-all. - """ - click.secho( - "activate-all-schedules is deprecated, please use activate-all instead.", - color="yellow", - ) - project = ctx.obj[_constants.CTX_PROJECT] - domain = ctx.obj[_constants.CTX_DOMAIN] - pkgs = ctx.obj[_constants.CTX_PACKAGES] - version = version or ctx.obj[_constants.CTX_VERSION] or _look_up_version_from_image_tag(_IMAGE.get()) - activate_all_impl(project, domain, version, pkgs) - - -@click.command("activate-all") -@click.option( - "-v", - "--version", - type=str, - help="Version to register tasks with. This is normally parsed from the" "image, but you can override here.", -) -@click.option( - "--ignore-schedules", - is_flag=True, - help="Activate all except for launch plans with schedules.", -) -@click.pass_context -def activate_all(ctx, version=None, ignore_schedules=False): - """ - This command will activate all found launch plans at the given version. If there are existing - active launch plans that collide on project, domain, and name, but differ on version, those will be - deactivated in favor of the version specified in this command. If a launch plan is associated with a schedule, - the schedule will also be deactivated or activated as appropriate. - - Note: - 1. Currently, this is not a transaction. Therefore, if the command fails, it is possible that some schedules - have been updated. - 2. 
If a launch plan is scheduled on an older version for a given project, domain, and name AND there is not a - matching scheduled launch plan found when running this command, the existing schedule will remain active - until it is manually disabled. - """ - project = ctx.obj[_constants.CTX_PROJECT] - domain = ctx.obj[_constants.CTX_DOMAIN] - pkgs = ctx.obj[_constants.CTX_PACKAGES] - version = version or ctx.obj[_constants.CTX_VERSION] or _look_up_version_from_image_tag(_IMAGE.get()) - activate_all_impl(project, domain, version, pkgs, ignore_schedules=ignore_schedules) - - -launch_plans.add_command(execute_launch_plan) -launch_plans.add_command(activate_all_schedules) -launch_plans.add_command(activate_all) diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index b23fa5d121..e0d3859122 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -5,12 +5,9 @@ import click from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES -from flytekit.clis.sdk_in_container.fast_register import fast_register from flytekit.clis.sdk_in_container.init import init -from flytekit.clis.sdk_in_container.launch_plan import launch_plans from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package -from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.configuration import internal as _internal_config from flytekit.configuration import platform as _platform_config @@ -107,10 +104,7 @@ def update_configuration_file(config_file_path): click.secho("Flyte Admin URL {}".format(_URL.get()), fg="green") -main.add_command(register) -main.add_command(fast_register) main.add_command(serialize) -main.add_command(launch_plans) main.add_command(package) main.add_command(local_cache) main.add_command(init) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py deleted file mode 100644 index 876629b851..0000000000 --- a/flytekit/clis/sdk_in_container/register.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging as _logging -import os as _os - -import click - -from flytekit.clis.sdk_in_container.constants import ( - CTX_DOMAIN, - CTX_PACKAGES, - CTX_PROJECT, - CTX_TEST, - CTX_VERSION, - domain_option, - project_option, - version_option, -) -from flytekit.common import utils as _utils -from flytekit.common.core import identifier as _identifier -from flytekit.common.tasks import task as _task -from flytekit.configuration.internal import DOMAIN as _DOMAIN -from flytekit.configuration.internal import IMAGE as _IMAGE -from flytekit.configuration.internal import PROJECT as _PROJECT -from flytekit.configuration.internal import VERSION as _VERSION -from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag -from flytekit.tools.module_loader import iterate_registerable_entities_in_order - - -def register_all(project, domain, pkgs, test, version): - if test: - click.echo("Test switch enabled, not doing anything...") - click.echo( - "Running task, workflow, and launch plan registration for {}, {}, {} with version {}".format( - project, domain, pkgs, version - ) - ) - - # m = module (i.e. python file) - # k = value of dir(m), type str - # o = object (e.g. 
SdkWorkflow) - for m, k, o in iterate_registerable_entities_in_order(pkgs): - name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type) - _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in)) - o._id = _identifier.Identifier(o.resource_type, project, domain, name, version) - - if test: - click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), o.id.name)) - else: - click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), o.id.name)) - o.register(project, domain, o.id.name, version) - - -def register_tasks_only(project, domain, pkgs, test, version): - if test: - click.echo("Test switch enabled, not doing anything...") - - click.echo("Running task only registration for {}, {}, {} with version {}".format(project, domain, pkgs, version)) - - # Discover all tasks by loading the module - for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask}): - name = _utils.fqdn(m.__name__, k, entity_type=t.resource_type) - - if test: - click.echo("Would register task {:20} {}".format("{}:".format(t.entity_type_text), name)) - else: - click.echo("Registering task {:20} {}".format("{}:".format(t.entity_type_text), name)) - t.register(project, domain, name, version) - - -@click.group("register") -@project_option -@domain_option -@version_option -# --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead -@click.option( - "--pkgs", - multiple=True, - help="DEPRECATED. This arg can only be used before the 'register' keyword", -) -@click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin") -@click.pass_context -def register(ctx, project, domain, version, pkgs=None, test=None): - """ - Run registration steps for the workflows in this container. - - Run with the --test switch for a dry run to see what will be registered. A default launch plan will also be - created, if a role can be found in the environment variables. - """ - if pkgs: - raise click.UsageError("--pkgs must now be specified before the 'register' keyword on the command line") - - version = version or _look_up_version_from_image_tag(_IMAGE.get()) - if not version: - raise click.UsageError("Could not find image from config, please specify a value for ``--version``") - - ctx.obj[CTX_PROJECT] = project - ctx.obj[CTX_DOMAIN] = domain - ctx.obj[CTX_VERSION] = version - ctx.obj[CTX_TEST] = test - _os.environ[_PROJECT.env_var] = project - _os.environ[_DOMAIN.env_var] = domain - _os.environ[_VERSION.env_var] = ctx.obj[CTX_VERSION] - - -@click.command("tasks") -@click.option( - "-v", - "--version", - type=str, - help="Version to register tasks with. This is normally parsed from the" "image, but you can override here.", -) -@click.pass_context -def tasks(ctx, version=None): - """ - Only register tasks. - """ - project = ctx.obj[CTX_PROJECT] - domain = ctx.obj[CTX_DOMAIN] - test = ctx.obj[CTX_TEST] - pkgs = ctx.obj[CTX_PACKAGES] - - version = version or ctx.obj[CTX_VERSION] or _look_up_version_from_image_tag(_IMAGE.get()) - register_tasks_only(project, domain, pkgs, test, version) - - -@click.command("workflows") -@click.option( - "-v", - "--version", - type=str, - help="Version to register tasks with. This is normally parsed from the" "image, but you can override here.", -) -@click.pass_context -def workflows(ctx, version=None): - """ - Register both tasks and workflows. Also create and register a default launch plan for all workflows. 
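-
-    A hypothetical invocation, for illustration only (per the deprecation note above,
-    ``--pkgs`` belongs on the top-level ``pyflyte`` command, and the exact flags may differ):
-
-        pyflyte --pkgs myapp.workflows register -p myproject -d development workflows -v abc123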
- """ - project = ctx.obj[CTX_PROJECT] - domain = ctx.obj[CTX_DOMAIN] - test = ctx.obj[CTX_TEST] - pkgs = ctx.obj[CTX_PACKAGES] - - version = version or ctx.obj[CTX_VERSION] or _look_up_version_from_image_tag(_IMAGE.get()) - register_all(project, domain, pkgs, test, version) - - -register.add_command(tasks) -register.add_command(workflows) diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index f02526a5dd..5fd68aebc4 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -14,24 +14,21 @@ import flytekit as _flytekit from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES -from flytekit.common import utils as _utils -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions.scopes import system_entry_point -from flytekit.common.exceptions.user import FlyteValidationException -from flytekit.common.tasks import task as _sdk_task -from flytekit.common.translator import get_serializable -from flytekit.common.utils import write_proto_to_file as _write_proto_to_file from flytekit.configuration import internal as _internal_config from flytekit.core import context_manager as flyte_context from flytekit.core.base_task import PythonTask from flytekit.core.launch_plan import LaunchPlan from flytekit.core.workflow import WorkflowBase +from flytekit.exceptions.scopes import system_entry_point +from flytekit.exceptions.user import FlyteValidationException from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import compute_digest as _compute_digest from flytekit.tools.fast_registration import filter_tar_file_fn as _filter_tar_file_fn -from flytekit.tools.module_loader import iterate_registerable_entities_in_order +from flytekit.tools.module_loader import trigger_loading +from flytekit.tools.translator import get_serializable # Identifier fields use placeholders for registration-time substitution. # Additional fields, such as auth and the raw output data prefix have more complex structures @@ -59,42 +56,6 @@ class SerializationMode(_Enum): FAST = 1 -@system_entry_point -def serialize_tasks_only(pkgs, folder=None): - """ - :param list[Text] pkgs: - :param Text folder: - - :return: - """ - # m = module (i.e. python file) - # k = value of dir(m), type str - # o = object (e.g. 
SdkWorkflow) - loaded_entities = [] - for m, k, o in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}): - name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type) - _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in)) - o._id = _identifier.Identifier( - o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER - ) - loaded_entities.append(o) - - zero_padded_length = _determine_text_chars(len(loaded_entities)) - for i, entity in enumerate(loaded_entities): - serialized = entity.serialize() - fname_index = str(i).zfill(zero_padded_length) - fname = "{}_{}.pb".format(fname_index, entity._id.name) - click.echo(" Writing {} to\n {}".format(entity._id, fname)) - if folder: - fname = _os.path.join(folder, fname) - _write_proto_to_file(serialized, fname) - - identifier_fname = "{}_{}.identifier.pb".format(fname_index, entity._id.name) - if folder: - identifier_fname = _os.path.join(folder, identifier_fname) - _write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname) - - def _should_register_with_admin(entity) -> bool: """ This is used in the code below. The translator.py module produces lots of objects (namely nodes and BranchNodes) @@ -140,8 +101,9 @@ def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List: serializable_tasks: typing.List[task_models.TaskSpec] = [ entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) ] - # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same metadata identifiers - # (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate tasks are considered invalid at registration + # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same + # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate + # tasks are considered invalid at registration # time and usually indicate user error, so we catch this common mistake at serialization time. duplicate_tasks = _find_duplicate_tasks(serializable_tasks) if len(duplicate_tasks) > 0: @@ -211,9 +173,6 @@ def serialize_all( :param flytekit_virtualenv_root: The full path of the virtual env in the container. """ - # m = module (i.e. python file) - # k = value of dir(m), type str - # o = object (e.g. SdkWorkflow) env = { _internal_config.CONFIGURATION_PATH.env_var: config_path if config_path @@ -242,29 +201,9 @@ def serialize_all( ) ctx = flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings) with flyte_context.FlyteContextManager.with_context(ctx) as ctx: - old_style_entities = [] - # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. 
The _get_entity_to_module - # function that this iterate calls only works on legacy objects - for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root): - name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type) - _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in)) - o._id = _identifier.Identifier( - o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER - ) - old_style_entities.append(o) - - serialized_old_style_entities = [] - for entity in old_style_entities: - if entity.has_registered: - _logging.info(f"Skipping entity {entity.id} because already registered") - continue - serialized_old_style_entities.append(entity.serialize()) - + trigger_loading(pkgs, local_source_root=local_source_root) click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows") - - new_api_model_values = get_registrable_entities(ctx) - - loaded_entities = serialized_old_style_entities + new_api_model_values + loaded_entities = get_registrable_entities(ctx) if folder is None: folder = "." persist_registrable_entities(loaded_entities, folder) @@ -349,18 +288,6 @@ def serialize(ctx, image, local_source_root, in_container_config_path, in_contai ctx.obj[CTX_PYTHON_INTERPRETER] = sys.executable -@click.command("tasks") -@click.option("-f", "--folder", type=click.Path(exists=True)) -@click.pass_context -def tasks(ctx, folder=None): - pkgs = ctx.obj[CTX_PACKAGES] - - if folder: - click.echo(f"Writing output to {folder}") - - serialize_tasks_only(pkgs, folder) - - @click.command("workflows") # For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the # directory for you so it shouldn't be a problem. @@ -423,7 +350,5 @@ def fast_workflows(ctx, folder=None): fast.add_command(fast_workflows) - -serialize.add_command(tasks) serialize.add_command(workflows) serialize.add_command(fast) diff --git a/flytekit/common/component_nodes.py b/flytekit/common/component_nodes.py deleted file mode 100644 index ea39a28dea..0000000000 --- a/flytekit/common/component_nodes.py +++ /dev/null @@ -1,157 +0,0 @@ -import logging as _logging - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.models.core import workflow as _workflow_model - - -class SdkTaskNode(_workflow_model.TaskNode, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, sdk_task): - """ - :param flytekit.common.tasks.task.SdkTask sdk_task: - """ - self._sdk_task = sdk_task - super(SdkTaskNode, self).__init__(None) - - @property - def reference_id(self): - """ - A globally unique identifier for the task. - :rtype: flytekit.models.core.identifier.Identifier - """ - return self._sdk_task.id - - @property - def sdk_task(self): - """ - :rtype: flytekit.common.tasks.task.SdkTask - """ - return self._sdk_task - - @classmethod - def promote_from_model(cls, base_model, tasks): - """ - Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it from the - engine. 
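-
-        If the task's identifier is already present in ``tasks``, the local template is reused;
-        otherwise the template is fetched from Admin. A minimal, illustrative call:
-
-            node = SdkTaskNode.promote_from_model(task_node_model, tasks={})  # empty map forces a fetch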
- - :param flytekit.models.core.workflow.TaskNode base_model: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: - :rtype: SdkTaskNode - """ - from flytekit.common.tasks import task as _task - - if base_model.reference_id in tasks: - t = tasks[base_model.reference_id] - _logging.info(f"Found existing task template for {t.id}, will not retrieve from Admin") - sdk_task = _task.SdkTask.promote_from_model(t) - sdk_task._has_registered = True - return cls(sdk_task) - - # If not found, fetch it from Admin - _logging.debug("Fetching task template for {} from Admin".format(base_model.reference_id)) - project = base_model.reference_id.project - domain = base_model.reference_id.domain - name = base_model.reference_id.name - version = base_model.reference_id.version - sdk_task = _task.SdkTask.fetch(project, domain, name, version) - return cls(sdk_task) - - -class SdkWorkflowNode(_workflow_model.WorkflowNode, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, sdk_workflow=None, sdk_launch_plan=None): - """ - :param flytekit.common.workflow.SdkWorkflow sdk_workflow: - :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: - """ - if sdk_workflow and sdk_launch_plan: - raise _system_exceptions.FlyteSystemException( - "SdkWorkflowNode cannot be called with both a workflow and " - "a launchplan specified, please pick one. WF: {} LP: {}", - sdk_workflow, - sdk_launch_plan, - ) - - self._sdk_workflow = sdk_workflow - self._sdk_launch_plan = sdk_launch_plan - sdk_wf_id = sdk_workflow.id if sdk_workflow else None - sdk_lp_id = sdk_launch_plan.id if sdk_launch_plan else None - super(SdkWorkflowNode, self).__init__(launchplan_ref=sdk_lp_id, sub_workflow_ref=sdk_wf_id) - - def __repr__(self): - """ - :rtype: Text - """ - if self.sdk_workflow is not None: - return "SdkWorkflowNode with workflow: {}".format(self.sdk_workflow) - return "SdkWorkflowNode with launch plan: {}".format(self.sdk_launch_plan) - - @property - def launchplan_ref(self): - """ - [Optional] A globally unique identifier for the launch plan. Should map to Admin. - :rtype: flytekit.models.core.identifier.Identifier - """ - return self._sdk_launch_plan.id if self._sdk_launch_plan else None - - @property - def sub_workflow_ref(self): - """ - [Optional] Reference to a subworkflow, that should be defined with the compiler context. 
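-        At most one of ``launchplan_ref`` and ``sub_workflow_ref`` is set; the constructor rejects
-        nodes that specify both a workflow and a launch plan.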
- :rtype: flytekit.models.core.identifier.Identifier - """ - return self._sdk_workflow.id if self._sdk_workflow else None - - @property - def sdk_launch_plan(self): - """ - :rtype: flytekit.common.launch_plan.SdkLaunchPlan - """ - return self._sdk_launch_plan - - @property - def sdk_workflow(self): - """ - :rtype: flytekit.common.workflow.SdkWorkflow - """ - return self._sdk_workflow - - @classmethod - def promote_from_model(cls, base_model, sub_workflows, tasks): - """ - :param flytekit.models.core.workflow.WorkflowNode base_model: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate] - sub_workflows: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: - :rtype: SdkWorkflowNode - """ - # put the import statement here to prevent circular dependency error - from flytekit.common import launch_plan as _launch_plan - from flytekit.common import workflow as _workflow - - project = base_model.reference.project - domain = base_model.reference.domain - name = base_model.reference.name - version = base_model.reference.version - if base_model.launchplan_ref is not None: - sdk_launch_plan = _launch_plan.SdkLaunchPlan.fetch(project, domain, name, version) - return cls(sdk_launch_plan=sdk_launch_plan) - elif base_model.sub_workflow_ref is not None: - # The workflow templates for sub-workflows should have been included in the original response - if base_model.reference in sub_workflows: - sw = sub_workflows[base_model.reference] - promoted = _workflow.SdkWorkflow.promote_from_model(sw, sub_workflows=sub_workflows, tasks=tasks) - return cls(sdk_workflow=promoted) - - # If not found for some reason, fetch it from Admin again. - # The reason there is a warning here but not for tasks is because sub-workflows should always be passed - # along. Ideally subworkflows are never even registered with Admin, so fetching from Admin ideally doesn't - # return anything. - _logging.warning( - "Your subworkflow with id {} is not included in the promote call.".format(base_model.reference) - ) - sdk_workflow = _workflow.SdkWorkflow.fetch(project, domain, name, version) - return cls(sdk_workflow=sdk_workflow) - else: - raise _system_exceptions.FlyteSystemException( - "Bad workflow node model, neither subworkflow nor " "launchplan specified." 
- ) diff --git a/flytekit/common/exceptions/__init__.py b/flytekit/common/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/common/interface.py b/flytekit/common/interface.py deleted file mode 100644 index c8f5160e15..0000000000 --- a/flytekit/common/interface.py +++ /dev/null @@ -1,163 +0,0 @@ -import six as _six - -from flytekit.common import promise as _promise -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import containers as _containers -from flytekit.common.types import helpers as _type_helpers -from flytekit.common.types import primitives as _primitives -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models - - -class BindingData(_literal_models.BindingData, metaclass=_sdk_bases.ExtendedSdkType): - @staticmethod - def _has_sub_bindings(m): - """ - :param dict[Text,T] or list[T]: - :rtype: bool - """ - for v in _six.itervalues(m) if isinstance(m, dict) else m: - if isinstance(v, (list, dict)) and BindingData._has_sub_bindings(v): - return True - elif isinstance(v, (_promise.Input, _promise.NodeOutput)): - return True - return False - - @classmethod - def promote_from_model(cls, model): - """ - :param flytekit.models.literals.BindingData model: - :rtype: BindingData - """ - return cls( - scalar=model.scalar, - collection=model.collection, - promise=model.promise, - map=model.map, - ) - - @classmethod - def from_python_std(cls, literal_type, t_value, upstream_nodes=None): - """ - :param flytekit.models.types.LiteralType literal_type: - :param T t_value: - :param list[flytekit.common.nodes.SdkNode] upstream_nodes: [Optional] Keeps track of the nodes upstream, - if applicable. - :rtype: BindingData - """ - scalar = None - collection = None - promise = None - map = None - downstream_sdk_type = _type_helpers.get_sdk_type_from_literal_type(literal_type) - if isinstance(t_value, _promise.Input): - if not downstream_sdk_type.is_castable_from(t_value.sdk_type): - _user_exceptions.FlyteTypeException( - t_value.sdk_type, - downstream_sdk_type, - additional_msg="When binding workflow input: {}".format(t_value), - ) - promise = t_value.promise - elif isinstance(t_value, _promise.NodeOutput): - if not downstream_sdk_type.is_castable_from(t_value.sdk_type): - _user_exceptions.FlyteTypeException( - t_value.sdk_type, - downstream_sdk_type, - additional_msg="When binding node output: {}".format(t_value), - ) - promise = t_value - if upstream_nodes is not None: - upstream_nodes.append(t_value.sdk_node) - elif isinstance(t_value, list): - if not issubclass(downstream_sdk_type, _containers.ListImpl): - raise _user_exceptions.FlyteTypeException( - type(t_value), - downstream_sdk_type, - received_value=t_value, - additional_msg="Cannot bind a list to a non-list type.", - ) - collection = _literal_models.BindingDataCollection( - [ - BindingData.from_python_std( - downstream_sdk_type.sub_type.to_flyte_literal_type(), - v, - upstream_nodes=upstream_nodes, - ) - for v in t_value - ] - ) - elif isinstance(t_value, dict) and ( - not issubclass(downstream_sdk_type, _primitives.Generic) or BindingData._has_sub_bindings(t_value) - ): - # TODO: This behavior should be embedded in the type engine. Someone should be able to alter behavior of - # TODO: binding logic by injecting their own type engine. The same goes for the list check above. 
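-            # A sketch of the intended behavior: map bindings would need to recurse into each value
-            # the way the list branch above does; until that exists, fail loudly instead of guessing.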
- raise NotImplementedError("TODO: Cannot use map bindings at the moment") - else: - sdk_value = downstream_sdk_type.from_python_std(t_value) - scalar = sdk_value.scalar - collection = sdk_value.collection - map = sdk_value.map - return cls(scalar=scalar, collection=collection, map=map, promise=promise) - - -class TypedInterface(_interface_models.TypedInterface, metaclass=_sdk_bases.ExtendedSdkType): - @classmethod - def promote_from_model(cls, model): - """ - :param flytekit.models.interface.TypedInterface model: - :rtype: TypedInterface - """ - return cls(model.inputs, model.outputs) - - def create_bindings_for_inputs(self, map_of_bindings): - """ - :param dict[Text, T] map_of_bindings: This can be scalar primitives, it can be node output references, - lists, etc.. - :rtype: (list[flytekit.models.literals.Binding], list[flytekit.common.nodes.SdkNode]) - :raises: flytekit.common.exceptions.user.FlyteAssertion - """ - binding_data = dict() - all_upstream_nodes = list() - for k in sorted(self.inputs): - var = self.inputs[k] - if k not in map_of_bindings: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) - - binding_data[k] = BindingData.from_python_std( - var.type, map_of_bindings[k], upstream_nodes=all_upstream_nodes - ) - - extra_inputs = set(binding_data.keys()) ^ set(map_of_bindings.keys()) - if len(extra_inputs) > 0: - raise _user_exceptions.FlyteAssertion( - "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs) - ) - - seen_nodes = set() - min_upstream = list() - for n in all_upstream_nodes: - if n not in seen_nodes: - seen_nodes.add(n) - min_upstream.append(n) - - return ( - [_literal_models.Binding(k, bd) for k, bd in _six.iteritems(binding_data)], - min_upstream, - ) - - def __repr__(self): - return "({inputs}) -> ({outputs})".format( - inputs=", ".join( - [ - "{}: {}".format(k, _type_helpers.get_sdk_type_from_literal_type(v.type)) - for k, v in _six.iteritems(self.inputs) - ] - ), - outputs=", ".join( - [ - "{}: {}".format(k, _type_helpers.get_sdk_type_from_literal_type(v.type)) - for k, v in _six.iteritems(self.outputs) - ] - ), - ) diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py deleted file mode 100644 index adb9ff979d..0000000000 --- a/flytekit/common/launch_plan.py +++ /dev/null @@ -1,498 +0,0 @@ -import datetime as _datetime -import logging as _logging -import uuid as _uuid - -import six as _six -from deprecated import deprecated as _deprecated - -from flytekit.common import interface as _interface -from flytekit.common import nodes as _nodes -from flytekit.common import promise as _promises -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import workflow_execution as _workflow_execution -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import hash as _hash_mixin -from flytekit.common.mixins import launchable as _launchable_mixin -from flytekit.common.mixins import registerable as _registerable -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import auth as _auth_config -from flytekit.configuration import sdk as _sdk_config -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models 
import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_model -from flytekit.models.core import identifier as _identifier_model -from flytekit.models.core import workflow as _workflow_models - - -class SdkLaunchPlan( - _launchable_mixin.LaunchableEntity, - _registerable.HasDependencies, - _registerable.RegisterableEntity, - _launch_plan_models.LaunchPlanSpec, - metaclass=_sdk_bases.ExtendedSdkType, -): - def __init__(self, *args, **kwargs): - super(SdkLaunchPlan, self).__init__(*args, **kwargs) - # Set all the attributes we expect this class to have - self._id = None - - # The interface is not set explicitly unless fetched in an engine context - self._interface = None - - @classmethod - def promote_from_model(cls, model) -> "SdkLaunchPlan": - """ - :param flytekit.models.launch_plan.LaunchPlanSpec model: - :rtype: SdkLaunchPlan - """ - return cls( - workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id), - default_inputs=_interface_models.ParameterMap( - { - k: _promises.Input.promote_from_model(v).rename_and_return_reference(k) - for k, v in _six.iteritems(model.default_inputs.parameters) - } - ), - fixed_inputs=model.fixed_inputs, - entity_metadata=model.entity_metadata, - labels=model.labels, - annotations=model.annotations, - auth_role=model.auth_role, - raw_output_data_config=model.raw_output_data_config, - max_parallelism=model.max_parallelism, - ) - - @_exception_scopes.system_entry_point - def register(self, project, domain, name, version): - """ - :param Text project: - :param Text domain: - :param Text name: - :param Text version: - """ - self.validate() - id_to_register = _identifier.Identifier( - _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version - ) - client = _flyte_engine.get_client() - try: - client.create_launch_plan(id_to_register, self) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - - self._id = id_to_register - self._has_registered = True - return str(self.id) - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, project, domain, name, version=None): - """ - This function uses the engine loader to call create a hydrated task from Admin. - :param Text project: - :param Text domain: - :param Text name: - :param Text version: [Optional] If not set, the SDK will fetch the active launch plan for the given project, - domain, and name. 
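-
-        A minimal sketch, assuming the identifiers below exist on your Admin instance:
-
-            lp = SdkLaunchPlan.fetch("flyteexamples", "development", "my_lp")  # no version: active launch plan
-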
- :rtype: SdkLaunchPlan - """ - from flytekit.common import workflow as _workflow - - launch_plan_id = _identifier.Identifier( - _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version - ) - - if launch_plan_id.version: - lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id) - else: - named_entity_id = _common_models.NamedEntityIdentifier( - launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name - ) - lp = _flyte_engine.get_client().get_active_launch_plan(named_entity_id) - - sdk_lp = cls.promote_from_model(lp.spec) - sdk_lp._id = lp.id - - # TODO: Add a test for this, and this function as a whole - wf_id = sdk_lp.workflow_id - lp_wf = _workflow.SdkWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) - sdk_lp._interface = lp_wf.interface - sdk_lp._has_registered = True - return sdk_lp - - @_exception_scopes.system_entry_point - def serialize(self): - """ - Serializing a launch plan should produce an object similar to what the registration step produces, - in preparation for actual registration to Admin. - - :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan - """ - return _launch_plan_models.LaunchPlan( - id=self.id, - spec=self, - closure=_launch_plan_models.LaunchPlanClosure( - state=None, - expected_inputs=_interface_models.ParameterMap({}), - expected_outputs=_interface_models.VariableMap({}), - ), - ).to_flyte_idl() - - @property - def id(self): - """ - :rtype: flytekit.common.core.identifier.Identifier - """ - return self._id - - @property - def is_scheduled(self): - """ - :rtype: bool - """ - if self.entity_metadata.schedule.cron_expression: - return True - elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: - return True - elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: - return True - else: - return False - - @property - def auth_role(self): - """ - :rtype: flytekit.models.common.AuthRole - """ - fixed_auth = super(SdkLaunchPlan, self).auth_role - if fixed_auth is not None and ( - fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None - ): - return fixed_auth - - assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() - kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() - - if not assumable_iam_role and _sdk_config.ROLE.get() is not None: - _logging.warning( - "Using deprecated `role` from config. Please update your config to use `assumable_iam_role` instead" - ) - assumable_iam_role = _sdk_config.ROLE.get() - return _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account, - ) - - @property - def workflow_id(self): - """ - :rtype: flytekit.common.core.identifier.Identifier - """ - return self._workflow_id - - @property - def interface(self): - """ - The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and - from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the= - object and get a node. 
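-
-        Illustrative usage (identifiers and inputs are hypothetical):
-
-            lp = SdkLaunchPlan.fetch("proj", "development", "my_lp", "v1")
-            node = lp(x=1)  # works because fetch() hydrated the interface
-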
- :rtype: flytekit.common.interface.TypedInterface - """ - return self._interface - - @property - def resource_type(self): - """ - Integer from _identifier.ResourceType enum - :rtype: int - """ - return _identifier_model.ResourceType.LAUNCH_PLAN - - @property - def entity_type_text(self): - """ - :rtype: Text - """ - return "Launch Plan" - - @property - def raw_output_data_config(self): - """ - :rtype: flytekit.models.common.RawOutputDataConfig - """ - raw_output_data_config = super(SdkLaunchPlan, self).raw_output_data_config - if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "": - return raw_output_data_config - - # If it was not set explicitly then let's use the value found in the configuration. - return _common_models.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get()) - - @_exception_scopes.system_entry_point - def validate(self): - # TODO: Validate workflow is satisfied - pass - - @_exception_scopes.system_entry_point - def update(self, state): - """ - :param int state: Enum value from flytekit.models.launch_plan.LaunchPlanState - """ - if not self.id: - raise _user_exceptions.FlyteAssertion( - "Failed to update launch plan because the launch plan's ID is not set. Please call register to fetch " - "or register the identifier first" - ) - return _flyte_engine.get_client().update_launch_plan(self.id, state) - - def _python_std_input_map_to_literal_map(self, inputs): - """ - :param dict[Text,Any] inputs: A dictionary of Python standard inputs that will be type-checked and compiled - to a LiteralMap - :rtype: flytekit.models.literals.LiteralMap - """ - return _type_helpers.pack_python_std_map_to_literal_map( - inputs, - {k: user_input.sdk_type for k, user_input in _six.iteritems(self.default_inputs.parameters) if k in inputs}, - ) - - @_deprecated(reason="Use launch_with_literals instead", version="0.9.0") - def execute_with_literals( - self, - project, - domain, - literal_inputs, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Deprecated. - """ - return self.launch_with_literals( - project, - domain, - literal_inputs, - name, - notification_overrides, - label_overrides, - annotation_overrides, - ) - - @_exception_scopes.system_entry_point - def launch_with_literals( - self, - project, - domain, - literal_inputs, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Executes the launch plan and returns the execution identifier. This version of execution is meant for when - you already have a LiteralMap of inputs. - - :param Text project: - :param Text domain: - :param flytekit.models.literals.LiteralMap literal_inputs: Inputs to the execution. - :param Text name: [Optional] If specified, an execution will be created with this name. Note: the name must - be unique within the context of the project and domain. - :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these - are the notifications that will be honored for this execution. An empty list signals to disable all - notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution - :param flytekit.models.common.AuthRole auth_role: - """ - # Kubernetes requires names starting with an alphabet for some resources. 
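-        # Generate a default name of the form "f<uuid4 hex>" so it starts with a letter and stays
-        # short enough for the Kubernetes resources derived from it.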
- name = name or "f" + _uuid.uuid4().hex[:19] - disable_all = notification_overrides == [] - if disable_all: - notification_overrides = None - else: - notification_overrides = _execution_models.NotificationList(notification_overrides or []) - disable_all = None - - client = _flyte_engine.get_client() - try: - exec_id = client.create_execution( - project, - domain, - name, - _execution_models.ExecutionSpec( - self.id, - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - "sdk", # TODO: get principle - 0, # TODO: Detect nesting - ), - notifications=notification_overrides, - disable_all=disable_all, - labels=label_overrides, - annotations=annotation_overrides, - auth_role=auth_role, - ), - literal_inputs, - ) - except _user_exceptions.FlyteEntityAlreadyExistsException: - exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) - execution = client.get_execution(exec_id) - return _workflow_execution.SdkWorkflowExecution.promote_from_model(execution) - - @_exception_scopes.system_entry_point - def __call__(self, *args, **input_map): - """ - :param list[T] args: Do not specify. Kwargs only are supported for this function. - :param dict[Text,T] input_map: Map of inputs. Can be statically defined or OutputReference links. - :rtype: flytekit.common.nodes.SdkNode - """ - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We " - "detected {} positional args.".format(len(args)) - ) - - # Take the default values from the launch plan - default_inputs = {k: v.sdk_default for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required} - default_inputs.update(input_map) - - bindings, upstream_nodes = self.interface.create_bindings_for_inputs(default_inputs) - - return _nodes.SdkNode( - id=None, - metadata=_workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0)), - bindings=sorted(bindings, key=lambda b: b.var), - upstream_nodes=upstream_nodes, - sdk_launch_plan=self, - ) - - def __repr__(self): - """ - :rtype: Text - """ - return "SdkLaunchPlan(ID: {} Interface: {} WF ID: {})".format(self.id, self.interface, self.workflow_id) - - -# The difference between this and the SdkLaunchPlan class is that this runnable class is supposed to only be used for -# launch plans loaded alongside the current Python interpreter. -class SdkRunnableLaunchPlan(_hash_mixin.HashOnReferenceMixin, SdkLaunchPlan): - def __init__( - self, - sdk_workflow, - default_inputs=None, - fixed_inputs=None, - role=None, - schedule=None, - notifications=None, - labels=None, - annotations=None, - auth_role=None, - raw_output_data_config=None, - ): - """ - :param flytekit.common.local_workflow.SdkRunnableWorkflow sdk_workflow: - :param dict[Text,flytekit.common.promise.Input] default_inputs: - :param dict[Text,Any] fixed_inputs: These inputs will be fixed and not need to be set when executing this - launch plan. - :param Text role: Deprecated. IAM role to execute this launch plan with. - :param flytekit.models.schedule.Schedule: Schedule to apply to this workflow. - :param list[flytekit.models.common.Notification]: List of notifications to apply to this launch plan. - :param flytekit.models.common.Labels labels: Any custom kubernetes labels to apply to workflows executed by this - launch plan. 
- :param flytekit.models.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows - executed by this launch plan. - Any custom kubernetes annotations to apply to workflows executed by this launch plan. - :param flytekit.models.common.Authrole auth_role: The auth method with which to execute the workflow. - :param flytekit.models.common.RawOutputDataConfig raw_output_data_config: Config for offloading data - """ - if role and auth_role: - raise ValueError("Cannot set both role and auth. Role is deprecated, use auth instead.") - - fixed_inputs = fixed_inputs or {} - default_inputs = default_inputs or {} - - if role: - auth_role = _common_models.AuthRole(assumable_iam_role=role) - - # The constructor for SdkLaunchPlan sets the id to None anyways so we don't bother passing in an ID. The ID - # should be set in one of three places, - # 1) When the object is registered (in the code above) - # 2) By the dynamic task code after this runnable object has already been __call__'ed. The SdkNode produced - # maintains a link to this object and will set the ID according to the configuration variables present. - # 3) When SdkLaunchPlan.fetch() is run - super(SdkRunnableLaunchPlan, self).__init__( - None, - _launch_plan_models.LaunchPlanMetadata( - schedule=schedule or _schedule_model.Schedule(""), - notifications=notifications or [], - ), - _interface_models.ParameterMap(default_inputs), - _type_helpers.pack_python_std_map_to_literal_map( - fixed_inputs, - { - k: _type_helpers.get_sdk_type_from_literal_type(var.type) - for k, var in _six.iteritems(sdk_workflow.interface.inputs) - if k in fixed_inputs - }, - ), - labels or _common_models.Labels({}), - annotations or _common_models.Annotations({}), - auth_role, - raw_output_data_config or _common_models.RawOutputDataConfig(""), - ) - self._interface = _interface.TypedInterface( - {k: v.var for k, v in _six.iteritems(default_inputs)}, - sdk_workflow.interface.outputs, - ) - self._upstream_entities = {sdk_workflow} - self._sdk_workflow = sdk_workflow - - @classmethod - def from_flyte_idl(cls, _): - raise _user_exceptions.FlyteAssertion( - "An SdkRunnableLaunchPlan must be created from a reference to local Python code only." - ) - - @classmethod - def promote_from_model(cls, model): - raise _user_exceptions.FlyteAssertion( - "An SdkRunnableLaunchPlan must be created from a reference to local Python code only." - ) - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, project, domain, name, version=None): - """ - This function uses the engine loader to call create a hydrated task from Admin. - :param Text project: - :param Text domain: - :param Text name: - :param Text version: - :rtype: SdkRunnableLaunchPlan - """ - raise _user_exceptions.FlyteAssertion( - "An SdkRunnableLaunchPlan must be created from a reference to local Python code only." 
- ) - - @property - def workflow_id(self): - """ - :rtype: flytekit.common.core.identifier.Identifier - """ - return self._sdk_workflow.id - - def __repr__(self): - """ - :rtype: Text - """ - return "SdkRunnableLaunchPlan(ID: {} Interface: {} WF ID: {})".format(self.id, self.interface, self.workflow_id) diff --git a/flytekit/common/local_workflow.py b/flytekit/common/local_workflow.py deleted file mode 100644 index eb2067578f..0000000000 --- a/flytekit/common/local_workflow.py +++ /dev/null @@ -1,388 +0,0 @@ -import uuid as _uuid -from typing import Any, Dict, List - -import six as _six -from six.moves import queue as _queue - -from flytekit.common import interface as _interface -from flytekit.common import launch_plan as _launch_plan -from flytekit.common import nodes as _nodes -from flytekit.common import promise as _promise -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import helpers as _type_helpers -from flytekit.common.workflow import SdkWorkflow -from flytekit.configuration import internal as _internal_config -from flytekit.models import common as _common_models -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_models -from flytekit.models.core import identifier as _identifier_model -from flytekit.models.core import workflow as _workflow_models - - -# Local-only wrapper around binding data and variables. Note that the Output object used by the end user is a yet -# another layer on top of this. -class Output(object): - def __init__(self, name, value, sdk_type=None, help=None): - """ - :param Text name: - :param T value: - :param U sdk_type: If specified, the value provided must cast to this type. Normally should be an instance of - flytekit.common.types.base_sdk_types.FlyteSdkType. But could also be something like: - - list[flytekit.common.types.base_sdk_types.FlyteSdkType], - dict[flytekit.common.types.base_sdk_types.FlyteSdkType,flytekit.common.types.base_sdk_types.FlyteSdkType], - (flytekit.common.types.base_sdk_types.FlyteSdkType, flytekit.common.types.base_sdk_types.FlyteSdkType, ...) - """ - if sdk_type is None: - # This syntax didn't work for some reason: sdk_type = sdk_type or Output._infer_type(value) - sdk_type = Output._infer_type(value) - sdk_type = _type_helpers.python_std_to_sdk_type(sdk_type) - - self._binding_data = _interface.BindingData.from_python_std(sdk_type.to_flyte_literal_type(), value) - self._var = _interface_models.Variable(sdk_type.to_flyte_literal_type(), help or "") - self._name = name - - def rename_and_return_reference(self, new_name): - self._name = new_name - return self - - @staticmethod - def _infer_type(value): - # TODO: Infer types - raise NotImplementedError( - "Currently the SDK cannot infer a workflow output type, so please use the type kwarg " - "when instantiating an output." - ) - - @property - def name(self): - """ - :rtype: Text - """ - return self._name - - @property - def binding_data(self): - """ - :rtype: flytekit.models.literals.BindingData - """ - return self._binding_data - - @property - def var(self): - """ - :rtype: flytekit.models.interface.Variable - """ - return self._var - - -class SdkRunnableWorkflow(SdkWorkflow): - """ - Wrapper class for workflows defined using Python, written in a Flyte workflow repo. This class is misnamed. It is - more appropriately called PythonWorkflow. 
The reason we are calling it SdkRunnableWorkflow instead is merely in - keeping with the established convention in other parts of this codebase. We will likely change this naming scheme - entirely before a 1.0 release. - - Being "runnable" or not "runnable" is not a distinction we care to make at this point. Do not read into it, - pretend it's not there. The purpose of this class is merely to differentiate between - i) A workflow object, as created by a user's workflow code, using the @workflow_class decorator for instance. If - you have one of these classes, it means you have the actual Python code available in the Python process you - are running. - ii) The SdkWorkflow object, which represents a workflow as retrieved from Flyte Admin. Anyone with access - to Admin, can instantiate an SdkWorkflow object by 'fetch'ing it. You don't need to have any of - the actual code checked out. SdkWorkflow's are effectively then the control plane model of a workflow, - represented as a Python object. - """ - - def __init__( - self, - inputs: List[_promise.Input], - nodes: List[_nodes.SdkNode], - interface, - output_bindings, - id=None, - metadata=None, - metadata_defaults=None, - disable_default_launch_plan=False, - ): - """ - :param list[flytekit.common.nodes.SdkNode] nodes: - :param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the - Workflow (inputs, outputs). This can include some optional parameters. - :param list[flytekit.models.literals.Binding] output_bindings: A list of output bindings that specify how to construct - workflow outputs. Bindings can pull node outputs or specify literals. All workflow outputs specified in - the interface field must be bound - in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes - to execute successfully in order to bind final outputs. - :param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is - globally unique across Flyte. - :param WorkflowMetadata metadata: This contains information on how to run the workflow. - :param flytekit.models.core.workflow.WorkflowMetadataDefaults metadata_defaults: Defaults to be passed - to nodes contained within workflow. - :param bool disable_default_launch_plan: Determines whether to create a default launch plan for the workflow. - """ - # Save the promise.Input objects for future use. - self._user_inputs = inputs - - # Set optional settings - id = ( - id - if id is not None - else _identifier.Identifier( - _identifier_model.ResourceType.WORKFLOW, - _internal_config.PROJECT.get(), - _internal_config.DOMAIN.get(), - _uuid.uuid4().hex, - _internal_config.VERSION.get(), - ) - ) - metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata() - metadata_defaults = ( - metadata_defaults if metadata_defaults is not None else _workflow_models.WorkflowMetadataDefaults() - ) - - super(SdkRunnableWorkflow, self).__init__( - nodes=nodes, - interface=interface, - output_bindings=output_bindings, - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - ) - - # Set this last as it's set in constructor - self._upstream_entities = set(n.executable_sdk_object for n in nodes) - self._should_create_default_launch_plan = not disable_default_launch_plan - - @property - def should_create_default_launch_plan(self): - """ - Determines whether registration flow should create a default launch plan for this workflow or not. 
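-        This is False only when the workflow was constructed with ``disable_default_launch_plan=True``.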
- :rtype: bool - """ - return self._should_create_default_launch_plan - - def __call__(self, *args, **input_map): - # Take the default values from the Inputs - compiled_inputs = {v.name: v.sdk_default for v in self.user_inputs if not v.sdk_required} - compiled_inputs.update(input_map) - - return super().__call__(*args, **compiled_inputs) - - @classmethod - def construct_from_class_definition( - cls, - inputs: List[_promise.Input], - outputs: List[Output], - nodes: List[_nodes.SdkNode], - metadata: _workflow_models.WorkflowMetadata = None, - metadata_defaults: _workflow_models.WorkflowMetadataDefaults = None, - disable_default_launch_plan: bool = False, - ) -> "SdkRunnableWorkflow": - """ - This constructor is here to provide backwards-compatibility for class-defined Workflows - - :param list[flytekit.common.promise.Input] inputs: - :param list[Output] outputs: - :param list[flytekit.common.nodes.SdkNode] nodes: - :param WorkflowMetadata metadata: This contains information on how to run the workflow. - :param flytekit.models.core.workflow.WorkflowMetadataDefaults metadata_defaults: Defaults to be passed - to nodes contained within workflow. - :param bool disable_default_launch_plan: Determines whether to create a default launch plan for the workflow or not. - - :rtype: SdkRunnableWorkflow - """ - for n in nodes: - for upstream in n.upstream_nodes: - if upstream.id is None: - raise _user_exceptions.FlyteAssertion( - "Some nodes contained in the workflow were not found in the workflow description. Please " - "ensure all nodes are either assigned to attributes within the class or an element in a " - "list, dict, or tuple which is stored as an attribute in the class." - ) - - id = _identifier.Identifier( - _identifier_model.ResourceType.WORKFLOW, - _internal_config.PROJECT.get(), - _internal_config.DOMAIN.get(), - _uuid.uuid4().hex, - _internal_config.VERSION.get(), - ) - interface = _interface.TypedInterface({v.name: v.var for v in inputs}, {v.name: v.var for v in outputs}) - - output_bindings = [_literal_models.Binding(v.name, v.binding_data) for v in outputs] - - return cls( - inputs=inputs, - nodes=nodes, - interface=interface, - output_bindings=output_bindings, - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - disable_default_launch_plan=disable_default_launch_plan, - ) - - @property - def id(self): - return self._id - - @id.setter - def id(self, new_id): - self._id = new_id - - @property - def user_inputs(self) -> List[_promise.Input]: - """ - :rtype: list[flytekit.common.promise.Input] - """ - return self._user_inputs - - def create_launch_plan( - self, - default_inputs: Dict[str, _promise.Input] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_models.Schedule = None, - role: str = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - assumable_iam_role: str = None, - kubernetes_service_account: str = None, - raw_output_data_prefix: str = None, - ): - """ - This method will create a launch plan object that can execute this workflow. - :param dict[Text,flytekit.common.promise.Input] default_inputs: - :param dict[Text,T] fixed_inputs: - :param flytekit.models.schedule.Schedule schedule: A schedule on which to execute this launch plan. - :param Text role: Deprecated. Use assumable_iam_role instead. - :param list[flytekit.models.common.Notification] notifications: A list of notifications to enact by default for - this launch plan. 
- :param flytekit.models.common.Labels labels: - :param flytekit.models.common.Annotations annotations: - :param cls: This parameter can be used by users to define an extension of a launch plan to instantiate. The - class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan. - :param Text assumable_iam_role: The IAM role to execute the workflow with. - :param Text kubernetes_service_account: The kubernetes service account to execute the workflow with. - :param Text raw_output_data_prefix: Bucket for offloaded data - :rtype: flytekit.common.launch_plan.SdkRunnableLaunchPlan - """ - # TODO: Actually ensure the parameters conform. - if role and (assumable_iam_role or kubernetes_service_account): - raise ValueError("Cannot set both role and auth. Role is deprecated, use auth instead.") - fixed_inputs = fixed_inputs or {} - merged_default_inputs = {v.name: v for v in self.user_inputs if v.name not in fixed_inputs} - merged_default_inputs.update(default_inputs or {}) - - if role: - assumable_iam_role = role # For backwards compatibility - auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account, - ) - - raw_output_config = _common_models.RawOutputDataConfig(raw_output_data_prefix or "") - - return _launch_plan.SdkRunnableLaunchPlan( - sdk_workflow=self, - default_inputs={ - k: user_input.rename_and_return_reference(k) for k, user_input in _six.iteritems(merged_default_inputs) - }, - fixed_inputs=fixed_inputs, - schedule=schedule, - notifications=notifications, - labels=labels, - annotations=annotations, - auth_role=auth_role, - raw_output_data_config=raw_output_config, - ) - - -def build_sdk_workflow_from_metaclass(metaclass, on_failure=None, disable_default_launch_plan=False, cls=None): - """ - :param T metaclass: This is the user-defined workflow class, prior to decoration. - :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy - when the workflow detects a failure. - :param bool disable_default_launch_plan: Determines whether to create a default launch plan for the workflow or not. - :param cls: This is the class that will be instantiated from the inputs, outputs, and nodes. This will be used - by users extending the base Flyte programming model. If set, it must be a subclass of PythonWorkflow. - - :rtype: SdkRunnableWorkflow - """ - inputs, outputs, nodes = _discover_workflow_components(metaclass) - metadata = _workflow_models.WorkflowMetadata(on_failure=on_failure if on_failure else None) - - return (cls or SdkRunnableWorkflow).construct_from_class_definition( - inputs=[i for i in sorted(inputs, key=lambda x: x.name)], - outputs=[o for o in sorted(outputs, key=lambda x: x.name)], - nodes=[n for n in sorted(nodes, key=lambda x: x.id)], - metadata=metadata, - disable_default_launch_plan=disable_default_launch_plan, - ) - - -def _discover_workflow_components(workflow_class): - """ - This task iterates over the attributes of a user-defined class in order to return a list of inputs, outputs and - nodes. - :param class workflow_class: User-defined class with task instances as attributes. 
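-
-    For example (illustrative only; ``Input``, ``Output``, ``Types``, and ``my_task`` stand in for
-    real old-style API objects):
-
-        class MyWorkflow(object):
-            x = Input(Types.Integer)
-            t = my_task(x=x)
-            out = Output(t.outputs.y, sdk_type=Types.Integer)
-
-    would yield inputs ``[x]``, outputs ``[out]``, and nodes ``[t]``.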
- :rtype: (list[flytekit.common.promise.Input], list[Output], list[flytekit.common.nodes.SdkNode]) - """ - - inputs = [] - outputs = [] - nodes = [] - - to_visit_objs = _queue.Queue() - top_level_attributes = set() - for attribute_name in dir(workflow_class): - to_visit_objs.put((attribute_name, getattr(workflow_class, attribute_name))) - top_level_attributes.add(attribute_name) - - # For all task instances defined within the workflow, bind them to this specific workflow and hook-up to the - # engine (when available) - visited_obj_ids = set() - while not to_visit_objs.empty(): - attribute_name, current_obj = to_visit_objs.get() - - current_obj_id = id(current_obj) - if current_obj_id in visited_obj_ids: - continue - visited_obj_ids.add(current_obj_id) - - if isinstance(current_obj, _nodes.SdkNode): - # TODO: If an attribute name is on the form node_name[index], the resulting - # node name might not be correct. - nodes.append(current_obj.assign_id_and_return(attribute_name)) - elif isinstance(current_obj, _promise.Input): - if attribute_name is None or attribute_name not in top_level_attributes: - raise _user_exceptions.FlyteValueException( - attribute_name, - "Detected workflow input specified outside of top level.", - ) - inputs.append(current_obj.rename_and_return_reference(attribute_name)) - elif isinstance(current_obj, Output): - if attribute_name is None or attribute_name not in top_level_attributes: - raise _user_exceptions.FlyteValueException( - attribute_name, - "Detected workflow output specified outside of top level.", - ) - outputs.append(current_obj.rename_and_return_reference(attribute_name)) - elif isinstance(current_obj, list) or isinstance(current_obj, set) or isinstance(current_obj, tuple): - for idx, value in enumerate(current_obj): - to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, idx), value)) - elif isinstance(current_obj, dict): - # Visit dictionary keys. - for key in current_obj.keys(): - to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, key), key)) - # Visit dictionary values. - for key, value in _six.iteritems(current_obj): - to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, key), value)) - return inputs, outputs, nodes - - -def _assign_indexed_attribute_name(attribute_name, index): - return "{}[{}]".format(attribute_name, index) diff --git a/flytekit/common/mixins/__init__.py b/flytekit/common/mixins/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/common/mixins/artifact.py b/flytekit/common/mixins/artifact.py deleted file mode 100644 index 41e5ebb0b7..0000000000 --- a/flytekit/common/mixins/artifact.py +++ /dev/null @@ -1,80 +0,0 @@ -import abc as _abc -import datetime as _datetime -import time as _time - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_models - - -class ExecutionArtifact(object, metaclass=_common_models.FlyteABCMeta): - @_abc.abstractproperty - def inputs(self): - """ - Returns the inputs to the execution in the standard Python format as dictated by the type engine. - :rtype: dict[Text, T] - """ - pass - - @_abc.abstractproperty - def outputs(self): - """ - Returns the outputs to the execution in the standard Python format as dictated by the type engine. If the - execution ended in error or the execution is in progress, an exception will be raised. - :rtype: dict[Text, T] - """ - pass - - @_abc.abstractproperty - def error(self): - """ - If execution is in progress, raise an exception. 
Otherwise, return None if no error was present upon - reaching completion. - :rtype: flytekit.models.core.execution.ExecutionError or None - """ - pass - - @_abc.abstractproperty - def is_complete(self): - """ - Dictates whether or not the execution is complete. - :rtype: bool - """ - pass - - @_abc.abstractmethod - def sync(self): - """ - Syncs the state of the underlying execution artifact with the state observed by the platform. - :rtype: None - """ - pass - - @_abc.abstractmethod - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. - :rtype: None - """ - pass - - def wait_for_completion(self, timeout=None, poll_interval=None): - """ - :param datetime.timedelta timeout: Amount of time to wait until the execution has completed before timing - out. If not set or set to None, this method will wait indefinitely. - :param datetime.timedelta poll_interval: Duration to wait between polling for a completion update. - :rtype: None - """ - poll_interval = poll_interval or _datetime.timedelta(seconds=30) - if timeout is None: - time_to_give_up = _datetime.datetime.max - else: - time_to_give_up = _datetime.datetime.utcnow() + timeout - - self._sync_closure() - while _datetime.datetime.utcnow() < time_to_give_up: - if self.is_complete: - self.sync() - return - _time.sleep(poll_interval.total_seconds()) - self._sync_closure() - raise _user_exceptions.FlyteTimeout("Execution {} did not complete before timeout.".format(self)) diff --git a/flytekit/common/mixins/launchable.py b/flytekit/common/mixins/launchable.py deleted file mode 100644 index 110ba663af..0000000000 --- a/flytekit/common/mixins/launchable.py +++ /dev/null @@ -1,129 +0,0 @@ -import abc as _abc - -from deprecated import deprecated as _deprecated - - -class LaunchableEntity(object, metaclass=_abc.ABCMeta): - def launch( - self, - project, - domain, - inputs=None, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Creates a remote execution from the entity and returns the execution identifier. - This version of launch is meant for when inputs are specified as Python native types/structures. - - :param Text project: - :param Text domain: - :param dict[Text, Any] inputs: A dictionary of Python standard inputs that will be type-checked, then compiled - to a LiteralMap. - :param Text name: [Optional] If specified, an execution will be created with this name. Note: the name must - be unique within the context of the project and domain. - :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these - are the notifications that will be honored for this execution. An empty list signals to disable all - notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: - :rtype: T - - """ - return self.launch_with_literals( - project, - domain, - self._python_std_input_map_to_literal_map(inputs or {}), - name=name, - notification_overrides=notification_overrides, - label_overrides=label_overrides, - annotation_overrides=annotation_overrides, - auth_role=auth_role, - ) - - @_deprecated(reason="Use launch instead", version="0.9.0") - def execute( - self, - project, - domain, - inputs=None, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Deprecated.
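The deleted wait_for_completion above reduces to a plain poll-and-sleep loop. A minimal standalone sketch of the same pattern, assuming a hypothetical artifact object that exposes the mixin's is_complete / sync() / _sync_closure() contract (not a flytekit API):

.. code-block:: python

    import datetime
    import time

    def wait_for_completion(artifact, timeout=None, poll_interval=None):
        # Default to polling every 30 seconds, as the deleted mixin did.
        poll_interval = poll_interval or datetime.timedelta(seconds=30)
        if timeout is None:
            deadline = datetime.datetime.max  # effectively wait forever
        else:
            deadline = datetime.datetime.utcnow() + timeout

        artifact._sync_closure()
        while datetime.datetime.utcnow() < deadline:
            if artifact.is_complete:
                artifact.sync()  # one final sync so callers observe terminal state
                return
            time.sleep(poll_interval.total_seconds())
            artifact._sync_closure()
        raise TimeoutError("Execution {} did not complete before timeout.".format(artifact))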
- """ - return self.launch( - project, - domain, - inputs=inputs, - name=name, - notification_overrides=notification_overrides, - label_overrides=label_overrides, - annotation_overrides=annotation_overrides, - ) - - @_abc.abstractmethod - def _python_std_input_map_to_literal_map(self, inputs): - pass - - @_abc.abstractmethod - def launch_with_literals( - self, - project, - domain, - literal_inputs, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Executes the entity and returns the execution identifier. This version of execution is meant for when - you already have a LiteralMap of inputs. - - :param Text project: - :param Text domain: - :param flytekit.models.literals.LiteralMap literal_inputs: Inputs to the execution. - :param Text name: [Optional] If specified, an execution will be created with this name. Note: the name must - be unique within the context of the project and domain. - :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these - are the notifications that will be honored for this execution. An empty list signals to disable all - notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: - :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier: - """ - pass - - @_deprecated(reason="Use launch_with_literals instead", version="0.9.0") - def execute_with_literals( - self, - project, - domain, - literal_inputs, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Deprecated. - """ - return self.launch_with_literals( - project, - domain, - literal_inputs, - name, - notification_overrides, - label_overrides, - annotation_overrides, - ) diff --git a/flytekit/common/mixins/registerable.py b/flytekit/common/mixins/registerable.py deleted file mode 100644 index 64fba47ef0..0000000000 --- a/flytekit/common/mixins/registerable.py +++ /dev/null @@ -1,195 +0,0 @@ -import abc as _abc -import importlib as _importlib -import inspect as _inspect -import logging as _logging -from typing import Set - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import utils as _utils -from flytekit.common.exceptions import system as _system_exceptions - - -class _InstanceTracker(_sdk_bases.ExtendedSdkType): - """ - This is either genius or terrible. Some of our tools iterate over modules and try to find Flyte entities - (Tasks, Workflows, Launch Plans) and then register them. However, if a task is imported via a command like: - - from package.module import some_task - - It is possible we will find a task reference twice, but then how do we know where it was defined? Ideally, we would - like to only register a task once and do so with the name where it is defined. This metaclass allows us to do this - by inspecting the call stack when __call__ is called on the metaclass (thus instantiating an object). 
- """ - - @staticmethod - def _find_instance_module(): - frame = _inspect.currentframe() - while frame: - if frame.f_code.co_name == "": - return frame.f_globals["__name__"] - frame = frame.f_back - return None - - def __call__(cls, *args, **kwargs): - o = super(_InstanceTracker, cls).__call__(*args, **kwargs) - o._instantiated_in = _InstanceTracker._find_instance_module() - return o - - -class FlyteEntity(object, metaclass=_sdk_bases.ExtendedSdkType): - @property - @_abc.abstractmethod - def resource_type(self): - """ - Integer from _identifier.ResourceType enum - :rtype: int - """ - pass - - @property - @_abc.abstractmethod - def entity_type_text(self): - """ - TODO: Rename to resource type text - :rtype: Text - """ - pass - - -class TrackableEntity(FlyteEntity, metaclass=_InstanceTracker): - def __init__(self, *args, **kwargs): - self._platform_valid_name = None - super(TrackableEntity, self).__init__(*args, **kwargs) - - @property - def instantiated_in(self): - """ - If found, we try to specify the module where the task was first instantiated. - :rtype: Optional[Text] - """ - return self._instantiated_in # Set in metaclass - - @property - def has_valid_name(self): - """ - :rtype: bool - """ - return self._platform_valid_name is not None and self._platform_valid_name != "" - - @property - def platform_valid_name(self): - """ - :rtype: Text - """ - return self._platform_valid_name - - def assign_name(self, name): - self._platform_valid_name = name - - def auto_assign_name(self): - """ - This function is a bit of trickster Python code that goes hand in hand with the _InstanceTracker metaclass - defined above. Thanks @matthewphsmith for this bit of ingenuity. - - For instance, if a user has code that looks like this: - - from some.other.module import wf - my_launch_plan = wf.create_launch_plan() - - @dynamic_task - def sample_task(wf_params): - yield my_launch_plan() - - This code means that we should have a launch plan with a name ending in "my_launch_plan", since that is the - name of the variable that the created launch plan gets assigned to. That is also the name that the launch plan - would be registered with. - - However, when the create_launch_plan() function runs, the Python interpreter has no idea where the created - object will be assigned to. It has no idea that the output of the create_launch_plan call is to be paired up - with a variable named "my_launch_plan". This function basically does this after the fact. Leveraging the - _instantiated_in field provided by the _InstanceTracker class above, this code will re-import the - module (ie Python file) that the object is in. Since it's already loaded, it's just retrieved from memory. - It then scans all objects in the module, and when an object match is found, it knows it's found the right - variable name. - - Just to drive the point home, this function is mostly needed for Launch Plans. Assuming that user code has: - - @python_task - def some_task() - - When Flytekit calls the module loader and loads the task, the name of the task is the name of the function - itself. It's known at time of creation. 
In contrast, when - - xyz = SomeWorkflow.create_launch_plan() - - is called, the name of the launch plan isn't known until after creation, it's not "SomeWorkflow", it's "xyz" - """ - _logging.debug("Running name auto assign") - m = _importlib.import_module(self.instantiated_in) - - for k in dir(m): - try: - if getattr(m, k) is self: - self._platform_valid_name = _utils.fqdn(m.__name__, k, entity_type=self.resource_type) - _logging.debug("Auto-assigning name to {}".format(self._platform_valid_name)) - return - except ValueError as err: - # Empty pandas dataframes behave weirdly here such that calling `m.df` raises: - # ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(), - # a.any() or a.all() - # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and - # continue looping through m. - _logging.warning("Caught ValueError {} while attempting to auto-assign name".format(err)) - pass - - _logging.error("Could not auto-assign name") - raise _system_exceptions.FlyteSystemException("Error looking for object while auto-assigning name.") - - -class RegisterableEntity(TrackableEntity): - def __init__(self, *args, **kwargs): - self._has_registered = False - super(RegisterableEntity, self).__init__(*args, **kwargs) - - @_abc.abstractmethod - def register(self, project, domain, name, version): - """ - :param Text project: The project in which to register this task. - :param Text domain: The domain in which to register this task. - :param Text name: The name to give this task. - :param Text version: The version in which to register this task. - """ - pass - - @_abc.abstractmethod - def serialize(self): - """ - Registerable entities also are required to be serialized. This allows flytekit to separate serialization from - the network call to Admin (mostly at least, if a Launch Plan is fetched for instance as part of another - workflow, it will still hit Admin). - """ - pass - - @property - def has_registered(self) -> bool: - return self._has_registered - - -class HasDependencies(object): - """ - This interface is meant to describe Flyte entities that can have upstream dependencies. For instance, currently a - launch plan depends on the underlying workflow, and a workflow is dependent on its tasks, and other launch plans, - and subworkflows. - """ - - def __init__(self, *args, **kwargs): - self._upstream_entities = set() - super(HasDependencies, self).__init__(*args, **kwargs) - - @property - def upstream_entities(self) -> Set[RegisterableEntity]: - """ - Task, workflow, and launch plan that need to be registered in advance of this workflow. 
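The name-recovery half of the trick can likewise be sketched independently of flytekit: re-import the module recorded by the tracker, then scan its attributes for an identity match. Here auto_assign_name is a hypothetical free function, and the ValueError guard mirrors the pandas caveat noted in the deleted code above.

.. code-block:: python

    import importlib

    def auto_assign_name(obj):
        # obj._instantiated_in was recorded by the instance-tracking metaclass.
        module = importlib.import_module(obj._instantiated_in)
        for attr_name in dir(module):
            try:
                if getattr(module, attr_name) is obj:
                    return attr_name  # the variable name the object was bound to
            except ValueError:
                # Some attributes (e.g. pandas DataFrames) raise ValueError
                # when inspected; skip them and keep scanning.
                continue
        raise RuntimeError("Error looking for object while auto-assigning name.")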
- :rtype: set[RegisterableEntity] - """ - return self._upstream_entities diff --git a/flytekit/common/nodes.py b/flytekit/common/nodes.py deleted file mode 100644 index fb68de2bc2..0000000000 --- a/flytekit/common/nodes.py +++ /dev/null @@ -1,463 +0,0 @@ -import abc as _abc -import logging as _logging -import os as _os - -import six as _six -from flyteidl.core import literals_pb2 as _literals_pb2 -from sortedcontainers import SortedDict as _SortedDict - -from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions -from flytekit.common import component_nodes as _component_nodes -from flytekit.common import constants as _constants -from flytekit.common import promise as _promise -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import utils as _common_utils -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import artifact as _artifact_mixin -from flytekit.common.mixins import hash as _hash_mixin -from flytekit.common.tasks import executions as _task_executions -from flytekit.common.types import helpers as _type_helpers -from flytekit.common.utils import _dnsify -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import common as _common_models -from flytekit.models import literals as _literal_models -from flytekit.models import node_execution as _node_execution_models -from flytekit.models.core import execution as _execution_models -from flytekit.models.core import workflow as _workflow_model - - -class ParameterMapper(_SortedDict, metaclass=_common_models.FlyteABCMeta): - """ - This abstract class provides functionality to reference specific inputs and outputs for a task instance. This - allows for syntax such as: - - my_task_instance.inputs.my_input - - And is especially useful for linking tasks together via outputs -> inputs in workflow definitions: - - my_second_task_instance(input=my_task_instance.outputs.my_output) - - Attributes: - Dynamically discovered. Only the keys for inputs/outputs can be referenced. - - Example: - - .. code-block:: python - - @inputs(a=Types.Integer) - @outputs(b=Types.String) - @python_task(version='1') - def my_task(wf_params, a, b): - pass - - input_link = my_task.inputs.a # Success! - output_link = my_task.outputs.b # Success! - - input_link = my_task.inputs.c # Attribute not found exception! - output_link = my_task.outputs.d # Attribute not found exception!
- - """ - - def __init__(self, type_map, node): - """ - :param dict[Text, flytekit.models.interface.Variable] type_map: - :param SdkNode node: - """ - super(ParameterMapper, self).__init__() - for key, var in _six.iteritems(type_map): - self[key] = self._return_mapping_object(node, _type_helpers.get_sdk_type_from_literal_type(var.type), key) - self._initialized = True - - def __getattr__(self, key): - if key == "iteritems" and hasattr(super(ParameterMapper, self), "items"): - return super(ParameterMapper, self).items - if hasattr(super(ParameterMapper, self), key): - return getattr(super(ParameterMapper, self), key) - if key not in self: - raise _user_exceptions.FlyteAssertion("{} doesn't exist.".format(key)) - return self[key] - - def __setattr__(self, key, value): - if "_initialized" in self.__dict__: - raise _user_exceptions.FlyteAssertion("Parameters are immutable.") - else: - super(ParameterMapper, self).__setattr__(key, value) - - @_abc.abstractmethod - def _return_mapping_object(self, sdk_node, sdk_type, name): - """ - :param flytekit.common.nodes.Node sdk_node: - :param flytekit.common.types.FlyteSdkType sdk_type: - :param Text name: - """ - pass - - -class OutputParameterMapper(ParameterMapper): - """ - This subclass of ParameterMapper is used to represent outputs for a given node. - """ - - def _return_mapping_object(self, sdk_node, sdk_type, name): - """ - :param flytekit.common.nodes.Node sdk_node: - :param flytekit.common.types.FlyteSdkType sdk_type: - :param Text name: - """ - return _promise.NodeOutput(sdk_node, sdk_type, name) - - -class SdkNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node, metaclass=_sdk_bases.ExtendedSdkType): - def __init__( - self, - id, - upstream_nodes, - bindings, - metadata, - sdk_task=None, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - parameter_mapping=True, - ): - """ - :param Text id: A workflow-level unique identifier that identifies this node in the workflow. "inputs" and - "outputs" are reserved node ids that cannot be used by other nodes. - :param flytekit.models.core.workflow.NodeMetadata metadata: Extra metadata about the node. - :param list[flytekit.models.literals.Binding] bindings: Specifies how to bind the underlying - interface's inputs. All required inputs specified in the underlying interface must be fulfilled. - :param list[SdkNode] upstream_nodes: Specifies execution dependencies for this node ensuring it will - only get scheduled to run after all its upstream nodes have completed. This node will have - an implicit dependency on any node that appears in inputs field. - :param flytekit.common.tasks.task.SdkTask sdk_task: The task to execute in this - node. - :param flytekit.common.workflow.SdkWorkflow sdk_workflow: The workflow to execute in this node. - :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: The launch plan to execute in this - node. - :param TODO sdk_branch: TODO - """ - non_none_entities = [ - entity for entity in [sdk_workflow, sdk_branch, sdk_launch_plan, sdk_task] if entity is not None - ] - if len(non_none_entities) != 1: - raise _user_exceptions.FlyteAssertion( - "An SDK node must have one underlying entity specified at once. 
Received the following " - "entities: {}".format(non_none_entities) - ) - - workflow_node = None - if sdk_workflow is not None: - workflow_node = _component_nodes.SdkWorkflowNode(sdk_workflow=sdk_workflow) - elif sdk_launch_plan is not None: - workflow_node = _component_nodes.SdkWorkflowNode(sdk_launch_plan=sdk_launch_plan) - - # TODO: this calls the constructor which means it will set all the upstream node ids to None if at the time of - # this instantiation, the upstream nodes have not had their nodes assigned yet. - super(SdkNode, self).__init__( - id=_dnsify(id) if id else None, - metadata=metadata, - inputs=bindings, - upstream_node_ids=[n.id for n in upstream_nodes], - output_aliases=[], # TODO: Are aliases a thing in SDK nodes - task_node=_component_nodes.SdkTaskNode(sdk_task) if sdk_task else None, - workflow_node=workflow_node, - branch_node=sdk_branch, - ) - self._upstream = upstream_nodes - self._executable_sdk_object = sdk_task or sdk_workflow or sdk_launch_plan - if parameter_mapping: - if not sdk_branch: - self._outputs = OutputParameterMapper(self._executable_sdk_object.interface.outputs, self) - else: - self._outputs = None - - @property - def executable_sdk_object(self): - return self._executable_sdk_object - - @classmethod - def promote_from_model(cls, model, sub_workflows, tasks): - """ - :param flytekit.models.core.workflow.Node model: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate] - sub_workflows: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: If specified, - these task templates will be passed to the SdkTaskNode promote_from_model call, and used - instead of fetching from Admin. - :rtype: SdkNode - """ - id = model.id - # This should never be called - if id == _constants.START_NODE_ID or id == _constants.END_NODE_ID: - _logging.warning("Should not call promote from model on a start node or end node {}".format(model)) - return None - - sdk_task_node, sdk_workflow_node = None, None - if model.task_node is not None: - sdk_task_node = _component_nodes.SdkTaskNode.promote_from_model(model.task_node, tasks) - elif model.workflow_node is not None: - sdk_workflow_node = _component_nodes.SdkWorkflowNode.promote_from_model( - model.workflow_node, sub_workflows, tasks - ) - else: - raise _system_exceptions.FlyteSystemException("Bad Node model, neither task nor workflow detected") - - # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a - # start node. In order to make the promoted SdkWorkflow look the same, we strip the start-node text back out. 
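The start-node rewiring loop that follows is mechanical once the constants are known. A runnable sketch with toy stand-ins for the IDL binding models; the constant values "start-node" and the empty string are assumptions about _constants, not verified here:

.. code-block:: python

    from types import SimpleNamespace

    START_NODE_ID = "start-node"   # assumed value of _constants.START_NODE_ID
    GLOBAL_INPUT_NODE_ID = ""      # assumed value of _constants.GLOBAL_INPUT_NODE_ID

    def strip_start_node_references(bindings):
        # Re-point any binding whose promise references the compiler-injected
        # start node at the global-input pseudo-node, so promoted workflows
        # look the same as locally-built ones.
        for binding in bindings:
            promise = binding.binding.promise
            if promise is not None and promise.node_id == START_NODE_ID:
                promise.node_id = GLOBAL_INPUT_NODE_ID
        return bindings

    # Toy stand-ins for the IDL models, just to exercise the function:
    promise = SimpleNamespace(node_id=START_NODE_ID)
    binding = SimpleNamespace(binding=SimpleNamespace(promise=promise))
    strip_start_node_references([binding])
    assert promise.node_id == GLOBAL_INPUT_NODE_ID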
- for i in model.inputs: - if i.binding.promise is not None and i.binding.promise.node_id == _constants.START_NODE_ID: - i.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID - - if sdk_task_node is not None: - return cls( - id=id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - sdk_task=sdk_task_node.sdk_task, - ) - elif sdk_workflow_node is not None: - if sdk_workflow_node.sdk_workflow is not None: - return cls( - id=id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - sdk_workflow=sdk_workflow_node.sdk_workflow, - ) - elif sdk_workflow_node.sdk_launch_plan is not None: - return cls( - id=id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - sdk_launch_plan=sdk_workflow_node.sdk_launch_plan, - ) - else: - raise _system_exceptions.FlyteSystemException( - "Bad SdkWorkflowNode model, both lp and workflow are None" - ) - else: - raise _system_exceptions.FlyteSystemException("Bad SdkNode model, both task and workflow nodes are empty") - - @property - def upstream_nodes(self): - """ - :rtype: list[SdkNode] - """ - return self._upstream - - @property - def upstream_node_ids(self): - """ - :rtype: list[Text] - """ - return [n.id for n in sorted(self.upstream_nodes, key=lambda x: x.id)] - - @property - def outputs(self): - """ - :rtype: dict[Text, flytekit.common.promise.NodeOutput] - """ - return self._outputs - - def assign_id_and_return(self, id): - """ - :param Text id: - :rtype: None - """ - if self.id: - raise _user_exceptions.FlyteAssertion( - "Error assigning ID: {} because {} is already assigned. Has this node been assigned to another " - "workflow already?".format(id, self) - ) - self._id = _dnsify(id) if id else None - self._metadata._name = id - return self - - def with_overrides(self, *args, **kwargs): - # TODO: Implement overrides - raise NotImplementedError("Overrides are not supported in Flyte yet.") - - @_exception_scopes.system_entry_point - def __lshift__(self, other): - """ - Add a node upstream of this node without necessarily mapping outputs -> inputs. - :param Node other: node to place upstream - """ - if hash(other) not in set(hash(n) for n in self.upstream_nodes): - self._upstream.append(other) - return other - - @_exception_scopes.system_entry_point - def __rshift__(self, other): - """ - Add a node downstream of this node without necessarily mapping outputs -> inputs. - - :param Node other: node to place downstream - """ - if hash(self) not in set(hash(n) for n in other.upstream_nodes): - other.upstream_nodes.append(self) - return other - - def __repr__(self): - """ - :rtype: Text - """ - return "Node(ID: {} Executable: {})".format(self.id, self._executable_sdk_object) - - -class SdkNodeExecution( - _node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact, metaclass=_sdk_bases.ExtendedSdkType -): - def __init__(self, *args, **kwargs): - super(SdkNodeExecution, self).__init__(*args, **kwargs) - self._task_executions = None - self._workflow_executions = None - self._inputs = None - self._outputs = None - - @property - def task_executions(self): - """ - Returns the underlying task executions in order of try attempt. 
- :rtype: list[flytekit.common.tasks.executions.SdkTaskExecution] - """ - return self._task_executions or [] - - @property - def workflow_executions(self): - """ - Returns the underlying workflow executions in order of try attempt. - :rtype: list[flytekit.common.workflow_execution.SdkWorkflowExecution] - """ - return self._workflow_executions or [] - - @property - def executions(self): - """ - Returns a list of generic execution artifacts. - :rtype: list[flytekit.common.mixins.artifact.ExecutionArtifact] - """ - return self.task_executions or self.workflow_executions or [] - - @property - def inputs(self): - """ - Returns the inputs to the execution in the standard Python format as dictated by the type engine. - :rtype: dict[Text, T] - """ - if self._inputs is None: - client = _flyte_engine.get_client() - execution_data = client.get_node_execution_data(self.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_inputs.literals): - input_map = execution_data.full_inputs - elif execution_data.inputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) - input_map = _literal_models.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - else: - input_map = _literal_models.LiteralMap({}) - - self._inputs = _type_helpers.unpack_literal_map_to_sdk_python_std(input_map) - return self._inputs - - @property - def outputs(self): - """ - Returns the outputs to the execution in the standard Python format as dictated by the type engine. If the - execution ended in error or the execution is in progress, an exception will be raised. - :rtype: dict[Text, T] - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting the outputs." - ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - - if self._outputs is None: - client = _flyte_engine.get_client() - execution_data = client.get_node_execution_data(self.id) - - # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_outputs.literals): - output_map = execution_data.full_outputs - - elif execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - output_map = _literal_models.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - else: - output_map = _literal_models.LiteralMap({}) - - self._outputs = _type_helpers.unpack_literal_map_to_sdk_python_std(output_map) - return self._outputs - - @property - def error(self): - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - :rtype: flytekit.models.core.execution.ExecutionError or None - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the node execution has completed before requesting error information." - ) - return self.closure.error - - @property - def is_complete(self): - """ - Dictates whether or not the execution is complete.
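This inline-versus-offloaded resolution appears four times in this diff (node and task executions, inputs and outputs). A duck-typed sketch of the shared shape, where download and parse stand in for Data.get_data and load_proto_from_file:

.. code-block:: python

    import os
    import tempfile

    def resolve_literal_map(full_literals, blob_bytes, blob_url, download, parse):
        # Small literal maps come back inline; large ones are offloaded and
        # only a (url, byte-count) pair is returned, so fetch and decode those.
        if full_literals:
            return full_literals
        if blob_bytes > 0:
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_path = os.path.join(tmp_dir, "literals.pb")
                download(blob_url, local_path)
                return parse(local_path)
        return {}  # neither inline nor offloaded: an empty map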
- :rtype: bool - """ - return self.closure.phase in { - _execution_models.NodeExecutionPhase.ABORTED, - _execution_models.NodeExecutionPhase.FAILED, - _execution_models.NodeExecutionPhase.SKIPPED, - _execution_models.NodeExecutionPhase.SUCCEEDED, - _execution_models.NodeExecutionPhase.TIMED_OUT, - } - - @classmethod - def promote_from_model(cls, base_model): - """ - :param _node_execution_models.NodeExecution base_model: - :rtype: SdkNodeExecution - """ - return cls( - closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri, metadata=base_model.metadata - ) - - def sync(self): - """ - Syncs the state of this object with that held by the platform. - :rtype: None - """ - if not self.is_complete or self.task_executions is not None: - client = _flyte_engine.get_client() - self._closure = client.get_node_execution(self.id).closure - task_executions = list(_iterate_task_executions(client, self.id)) - self._task_executions = [_task_executions.SdkTaskExecution.promote_from_model(te) for te in task_executions] - # TODO: Sub-workflows too once implemented - - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. - :rtype: None - """ - client = _flyte_engine.get_client() - self._closure = client.get_node_execution(self.id).closure diff --git a/flytekit/common/notifications.py b/flytekit/common/notifications.py deleted file mode 100644 index 09b1d11358..0000000000 --- a/flytekit/common/notifications.py +++ /dev/null @@ -1,101 +0,0 @@ -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_model -from flytekit.models.core import execution as _execution_model - - -class Notification(_common_model.Notification, metaclass=_sdk_bases.ExtendedSdkType): - - VALID_PHASES = { - _execution_model.WorkflowExecutionPhase.ABORTED, - _execution_model.WorkflowExecutionPhase.FAILED, - _execution_model.WorkflowExecutionPhase.SUCCEEDED, - _execution_model.WorkflowExecutionPhase.TIMED_OUT, - } - - def __init__(self, phases, email=None, pager_duty=None, slack=None): - """ - :param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for - terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase - """ - self._validate_phases(phases) - super(Notification, self).__init__(phases, email=email, pager_duty=pager_duty, slack=slack) - - def _validate_phases(self, phases): - """ - :param list[int] phases: - """ - if len(phases) == 0: - raise _user_exceptions.FlyteAssertion("You must specify at least one phase for a notification.") - for phase in phases: - if phase not in self.VALID_PHASES: - raise _user_exceptions.FlyteValueException( - phase, - self.VALID_PHASES, - additional_message="Notifications can only be specified on terminal states.", - ) - - @classmethod - def from_flyte_idl(cls, p): - """ - :param flyteidl.admin.common_pb2.Notification p: - :rtype: Notification - """ - if p.HasField("email"): - return cls(p.phases, p.email.recipients_email) - elif p.HasField("pager_duty"): - return cls(p.phases, p.pager_duty.recipients_email) - else: - return cls(p.phases, p.slack.recipients_email) - - -class PagerDuty(Notification): - def __init__(self, phases, recipients_email): - """ - :param list[Text] recipients_email: A required non-empty list of recipients for the notification. 
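The terminal-phase rule used by the Notification class above is simple set membership. A standalone sketch with readable string stand-ins in place of the integer phase enum:

.. code-block:: python

    TERMINAL_PHASES = {"ABORTED", "FAILED", "SUCCEEDED", "TIMED_OUT"}

    def validate_phases(phases):
        # Notifications only make sense on terminal phases, and at least one
        # phase must be given.
        if not phases:
            raise ValueError("You must specify at least one phase for a notification.")
        for phase in phases:
            if phase not in TERMINAL_PHASES:
                raise ValueError(
                    "{}: notifications can only be specified on terminal states.".format(phase)
                )

    validate_phases(["FAILED", "TIMED_OUT"])  # ok; a non-terminal phase would raise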
- """ - super(PagerDuty, self).__init__(phases, pager_duty=_common_model.PagerDutyNotification(recipients_email)) - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.common.Notification base_model: - :rtype: Notification - """ - return cls(base_model.phases, base_model.pager_duty.recipients_email) - - -class Email(Notification): - def __init__(self, phases, recipients_email): - """ - :param list[Text] recipients_email: A required non-empty list of recipients for the notification. - :param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for - terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase - """ - super(Email, self).__init__(phases, email=_common_model.EmailNotification(recipients_email)) - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.common.Notification base_model: - :rtype: Notification - """ - return cls(base_model.phases, base_model.email.recipients_email) - - -class Slack(Notification): - def __init__(self, phases, recipients_email): - """ - :param list[Text] recipients_email: A required non-empty list of recipients for the notification. - :param list[int] phases: A required list of phases for which to fire the event. Events can only be fired for - terminal phases. Phases should be as defined in: flytekit.models.core.execution.WorkflowExecutionPhase - """ - super(Slack, self).__init__(phases, slack=_common_model.SlackNotification(recipients_email)) - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.common.Notification base_model: - :rtype: Notification - """ - return cls(base_model.phases, base_model.slack.recipients_email) diff --git a/flytekit/common/promise.py b/flytekit/common/promise.py deleted file mode 100644 index 79352dd638..0000000000 --- a/flytekit/common/promise.py +++ /dev/null @@ -1,169 +0,0 @@ -from flytekit.common import constants as _constants -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_models -from flytekit.models import types as _type_models - - -class Input(_interface_models.Parameter, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, name, sdk_type, help=None, **kwargs): - """ - :param Text name: - :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: This is the SDK type necessary to create an - input to this workflow. - :param Text help: An optional help string to describe the input to users. - :param bool required: If set to True, default must be None - :param T default: If this is not a required input, the value will default to this value. 
- """ - param_default = None - if "required" not in kwargs and "default" not in kwargs: - # Neither required or default is set so assume required - required = True - default = None - elif kwargs.get("required", False) and "default" in kwargs: - # Required cannot be set to True and have a default specified - raise _user_exceptions.FlyteAssertion("Default cannot be set when required is True") - elif "default" in kwargs: - # If default is specified, then required must be false and the value is whatever is specified - required = None - default = kwargs["default"] - param_default = sdk_type.from_python_std(default) - else: - # If no default is set, but required is set, then the behavior is determined by required == True or False - default = None - required = kwargs["required"] - if not required: - # If required == False, we assume default to be None - param_default = sdk_type.from_python_std(default) - required = None - - self._sdk_required = required or False - self._sdk_default = default - self._help = help - self._sdk_type = sdk_type - self._promise = _type_models.OutputReference(_constants.GLOBAL_INPUT_NODE_ID, name) - self._name = name - super(Input, self).__init__( - _interface_models.Variable(type=sdk_type.to_flyte_literal_type(), description=help or ""), - required=required, - default=param_default, - ) - - def rename_and_return_reference(self, new_name): - self._promise._var = new_name - return self - - @property - def name(self): - """ - :rtype: Text - """ - return self._promise.var - - @property - def promise(self): - """ - :rtype: flytekit.models.types.OutputReference - """ - return self._promise - - @property - def sdk_required(self): - """ - :rtype: bool - """ - return self._sdk_required - - @property - def sdk_default(self): - """ - :rtype: T - """ - return self._sdk_default - - @property - def help(self): - """ - :rtype: Text - """ - return self._help - - @property - def sdk_type(self): - """ - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - return self._sdk_type - - def __repr__(self): - return "Input({}, {}, required={}, help={})".format(self.name, self.sdk_type, self.required, self.help) - - @classmethod - def promote_from_model(cls, model): - """ - :param flytekit.models.interface.Parameter model: - :rtype: Input - """ - sdk_type = _type_helpers.get_sdk_type_from_literal_type(model.var.type) - - if model.default is not None: - default_value = sdk_type.from_flyte_idl(model.default.to_flyte_idl()).to_python_std() - return cls( - "", - sdk_type, - help=model.var.description, - required=False, - default=default_value, - ) - else: - return cls("", sdk_type, help=model.var.description, required=True) - - -class NodeOutput(_type_models.OutputReference, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, sdk_node, sdk_type, var): - """ - :param sdk_node: - :param sdk_type: deprecated in mypy flytekit. - :param var: - """ - self._node = sdk_node - self._type = sdk_type - super(NodeOutput, self).__init__(self._node.id, var) - - @property - def node_id(self): - """ - Override the underlying node_id property to refer to SdkNode. - :rtype: Text - """ - return self.sdk_node.id - - @classmethod - def promote_from_model(cls, model): - """ - :param flytekit.models.types.OutputReference model: - :rtype: NodeOutput - """ - raise _user_exceptions.FlyteAssertion( - "A NodeOutput cannot be promoted from a protobuf because it must be " - "contextualized by an existing SdkNode." 
- ) - - @property - def sdk_node(self): - """ - :rtype: flytekit.common.nodes.SdkNode - """ - return self._node - - @property - def sdk_type(self): - """ - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - return self._type - - def __repr__(self): - s = f"NodeOutput({self.sdk_node if self.sdk_node.id is not None else None}:{self.var})" - return s diff --git a/flytekit/common/schedules.py b/flytekit/common/schedules.py deleted file mode 100644 index d6e71c6f83..0000000000 --- a/flytekit/common/schedules.py +++ /dev/null @@ -1,195 +0,0 @@ -import datetime as _datetime -import re as _re - -import croniter as _croniter - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import schedule as _schedule_models - - -class _ExtendedSchedule(_schedule_models.Schedule): - @classmethod - def from_flyte_idl(cls, proto): - """ - :param flyteidl.admin.schedule_pb2.Schedule proto: - :rtype: _ExtendedSchedule - """ - return cls.promote_from_model(_schedule_models.Schedule.from_flyte_idl(proto)) - - -class CronSchedule(_ExtendedSchedule, metaclass=_sdk_bases.ExtendedSdkType): - _VALID_CRON_ALIASES = [ - "hourly", - "hours", - "@hourly", - "daily", - "days", - "@daily", - "weekly", - "weeks", - "@weekly", - "monthly", - "months", - "@monthly", - "annually", - "@annually", - "yearly", - "years", - "@yearly", - ] - - # Not a perfect regex but good enough and simple to reason about - _OFFSET_PATTERN = _re.compile("([-+]?)P([-+0-9YMWD]+)?(T([-+0-9HMS.,]+)?)?") - - def __init__(self, cron_expression=None, schedule=None, offset=None, kickoff_time_input_arg=None): - """ - :param Text cron_expression: - :param Text schedule: - :param Text offset: - :param Text kickoff_time_input_arg: - """ - if cron_expression is None and schedule is None: - raise _user_exceptions.FlyteAssertion("Either `cron_expression` or `schedule` should be specified.") - - if cron_expression is not None and offset is not None: - raise _user_exceptions.FlyteAssertion("Only `schedule` is supported when specifying `offset`.") - - if cron_expression is not None: - CronSchedule._validate_expression(cron_expression) - - if schedule is not None: - CronSchedule._validate_schedule(schedule) - - if offset is not None: - CronSchedule._validate_offset(offset) - - super(CronSchedule, self).__init__( - kickoff_time_input_arg, - cron_expression=cron_expression, - cron_schedule=_schedule_models.Schedule.CronSchedule(schedule, offset) if schedule is not None else None, - ) - - @staticmethod - def _validate_expression(cron_expression): - """ - Ensures that the set value is a valid cron string. We use the format used in Cloudwatch and the best - explanation can be found here: - https://docs.aws.amazon.com/AmazonCloudWatch/latest/events/ScheduledEvents.html#CronExpressions - :param Text cron_expression: cron expression - """ - # We use the croniter lib to validate our cron expression. Since on the admin side we use Cloudwatch, - # we have a couple checks in order to line up Cloudwatch with Croniter. - tokens = cron_expression.split() - if len(tokens) != 6: - raise _user_exceptions.FlyteAssertion( - "Cron expression is invalid. A cron expression must have 6 fields. Cron expressions are in the " - "format of: `minute hour day-of-month month day-of-week year`. " - "Use `schedule` for a 5-field cron expression. Received: `{}`".format(cron_expression) - ) - - if tokens[2] != "?" and tokens[4] != "?": - raise _user_exceptions.FlyteAssertion( - "Schedule string is invalid. A cron expression must have a '?' for either day-of-month or " - "day-of-week. Please specify '?' for one of those fields. Cron expressions are in the format of: " - "minute hour day-of-month month day-of-week year.\n\n" - "For more information: " - "https://docs.aws.amazon.com/AmazonCloudWatch/latest/events/ScheduledEvents.html#CronExpressions" - ) - - try: - # Cut to 5 fields and just assume year field is good because croniter treats the 6th field as seconds. - # TODO: Parse this field ourselves and check - _croniter.croniter(" ".join(cron_expression.replace("?", "*").split()[:5])) - except Exception: - raise _user_exceptions.FlyteAssertion( - "Schedule string is invalid. The cron expression was found to be invalid." - " Provided cron expr: {}".format(cron_expression) - ) - - @staticmethod - def _validate_schedule(schedule): - if schedule.lower() not in CronSchedule._VALID_CRON_ALIASES: - try: - _croniter.croniter(schedule) - except Exception: - raise _user_exceptions.FlyteAssertion( - "Schedule is invalid. It must be set to either a cron alias or a valid cron expression." - " Provided schedule: {}".format(schedule) - ) - - @staticmethod - def _validate_offset(offset): - if CronSchedule._OFFSET_PATTERN.fullmatch(offset) is None: - raise _user_exceptions.FlyteAssertion( - "Offset is invalid. It must be an ISO 8601 duration. Provided offset: {}".format(offset) - ) - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.schedule.Schedule base_model: - :rtype: CronSchedule - """ - return cls( - cron_expression=base_model.cron_expression, - schedule=base_model.cron_schedule.schedule if base_model.cron_schedule is not None else None, - offset=base_model.cron_schedule.offset if base_model.cron_schedule is not None else None, - kickoff_time_input_arg=base_model.kickoff_time_input_arg, - ) - - -class FixedRate(_ExtendedSchedule, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, duration, kickoff_time_input_arg=None): - """ - :param datetime.timedelta duration: - :param Text kickoff_time_input_arg: - """ - super(FixedRate, self).__init__(kickoff_time_input_arg, rate=self._translate_duration(duration)) - - @staticmethod - def _translate_duration(duration): - """ - :param datetime.timedelta duration: timedelta between runs - :rtype: flytekit.models.schedule.Schedule.FixedRate - """ - _SECONDS_TO_MINUTES = 60 - _SECONDS_TO_HOURS = _SECONDS_TO_MINUTES * 60 - _SECONDS_TO_DAYS = _SECONDS_TO_HOURS * 24 - - if duration.microseconds != 0 or duration.seconds % _SECONDS_TO_MINUTES != 0: - raise _user_exceptions.FlyteAssertion( - "Granularity of less than a minute is not supported for FixedRate schedules.
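Pulled out of the class, the whole Cloudwatch-cron check above amounts to: six fields, a '?' in day-of-month or day-of-week, then a croniter parse of the first five fields with '?' normalized to '*'. A runnable sketch:

.. code-block:: python

    from croniter import croniter

    def validate_cloudwatch_cron(cron_expression):
        tokens = cron_expression.split()
        if len(tokens) != 6:
            # Cloudwatch crons are: minute hour day-of-month month day-of-week year
            raise ValueError("A cron expression must have 6 fields: {!r}".format(cron_expression))
        if tokens[2] != "?" and tokens[4] != "?":
            raise ValueError("Either day-of-month or day-of-week must be '?'")
        # croniter validates 5-field expressions; drop the year and normalize '?'.
        croniter(" ".join(cron_expression.replace("?", "*").split()[:5]))

    validate_cloudwatch_cron("0 12 * * ? *")  # every day at 12:00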
Received: {}".format( - duration - ) - ) - elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: - return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), - _schedule_models.Schedule.FixedRateUnit.DAY, - ) - elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: - return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), - _schedule_models.Schedule.FixedRateUnit.HOUR, - ) - else: - return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), - _schedule_models.Schedule.FixedRateUnit.MINUTE, - ) - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.schedule.Schedule base_model: - :rtype: FixedRate - """ - if base_model.rate.unit == _schedule_models.Schedule.FixedRateUnit.DAY: - duration = _datetime.timedelta(days=base_model.rate.value) - elif base_model.rate.unit == _schedule_models.Schedule.FixedRateUnit.HOUR: - duration = _datetime.timedelta(hours=base_model.rate.value) - else: - duration = _datetime.timedelta(minutes=base_model.rate.value) - - return cls(duration, kickoff_time_input_arg=base_model.kickoff_time_input_arg) diff --git a/flytekit/common/sdk_bases.py b/flytekit/common/sdk_bases.py deleted file mode 100644 index d082302343..0000000000 --- a/flytekit/common/sdk_bases.py +++ /dev/null @@ -1,22 +0,0 @@ -import abc as _abc - -from flytekit.models import common as _common - - -class ExtendedSdkType(_common.FlyteType, metaclass=_common.FlyteABCMeta): - """ - Abstract class that all SDK objects must inherit from. This provides the ability to promote a data model object - into an actionable object. - """ - - @_abc.abstractmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.common.FlyteIdlEntity base_model: - :rtype: ExtendedSdkType - """ - pass - - def from_flyte_idl(cls, pb2_object): - base_model = super(ExtendedSdkType, cls).from_flyte_idl(pb2_object) - return cls.promote_from_model(base_model) diff --git a/flytekit/common/tasks/__init__.py b/flytekit/common/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/common/tasks/executions.py b/flytekit/common/tasks/executions.py deleted file mode 100644 index d87c558c09..0000000000 --- a/flytekit/common/tasks/executions.py +++ /dev/null @@ -1,153 +0,0 @@ -import os as _os - -import six as _six -from flyteidl.core import literals_pb2 as _literals_pb2 - -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import utils as _common_utils -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import artifact as _artifact_mixin -from flytekit.common.types import helpers as _type_helpers -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literal_models -from flytekit.models.admin import task_execution as _task_execution_model -from flytekit.models.core import execution as _execution_models - - -class SdkTaskExecution( - _task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact, metaclass=_sdk_bases.ExtendedSdkType -): - def __init__(self, *args, **kwargs): - super(SdkTaskExecution, self).__init__(*args, **kwargs) - self._inputs = None - self._outputs = None - - @property - def is_complete(self): - """ - Dictates whether or not the execution 
is complete. - :rtype: bool - """ - return self.closure.phase in { - _execution_models.TaskExecutionPhase.ABORTED, - _execution_models.TaskExecutionPhase.FAILED, - _execution_models.TaskExecutionPhase.SUCCEEDED, - } - - @property - def inputs(self): - """ - Returns the inputs of the task execution in the standard Python format that is produced by - the type engine. - :rtype: dict[Text, T] - """ - if self._inputs is None: - client = _flyte_engine.get_client() - execution_data = client.get_task_execution_data(self.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_inputs.literals): - input_map = execution_data.full_inputs - elif execution_data.inputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) - input_map = _literal_models.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - else: - input_map = _literal_models.LiteralMap({}) - - self._inputs = _type_helpers.unpack_literal_map_to_sdk_python_std(input_map) - return self._inputs - - @property - def outputs(self): - """ - Returns the outputs of the task execution, if available, in the standard Python format that is produced by - the type engine. If not available, perhaps due to execution being in progress or an error being produced, - this will raise an exception. - :rtype: dict[Text, T] - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the task execution has completed before requesting the outputs." - ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - - if self._outputs is None: - client = _flyte_engine.get_client() - execution_data = client.get_task_execution_data(self.id) - - # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_outputs.literals): - output_map = execution_data.full_outputs - - elif execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - output_map = _literal_models.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - else: - output_map = _literal_models.LiteralMap({}) - self._outputs = _type_helpers.unpack_literal_map_to_sdk_python_std(output_map) - return self._outputs - - @property - def error(self): - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - :rtype: flytekit.models.core.execution.ExecutionError or None - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until the task execution has completed before requesting error information."
- ) - return self.closure.error - - def get_child_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - from flytekit.common import nodes as _nodes - - if not self.is_parent: - raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.") - client = _flyte_engine.get_client() - models = { - v.id.node_id: v - for v in _iterate_node_executions(client, task_execution_identifier=self.id, filters=filters) - } - - return {k: _nodes.SdkNodeExecution.promote_from_model(v) for k, v in _six.iteritems(models)} - - @classmethod - def promote_from_model(cls, base_model): - """ - :param _task_execution_model.TaskExecution base_model: - :rtype: SdkTaskExecution - """ - return cls( - closure=base_model.closure, - id=base_model.id, - input_uri=base_model.input_uri, - is_parent=base_model.is_parent, - ) - - def sync(self): - self._sync_closure() - - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. - :rtype: None - """ - client = _flyte_engine.get_client() - self._closure = client.get_task_execution(self.id).closure diff --git a/flytekit/common/tasks/generic_spark_task.py b/flytekit/common/tasks/generic_spark_task.py deleted file mode 100644 index a83a7afbe4..0000000000 --- a/flytekit/common/tasks/generic_spark_task.py +++ /dev/null @@ -1,147 +0,0 @@ -import sys as _sys - -import six as _six -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit import __version__ -from flytekit.common import interface as _interface -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import task as _base_tasks -from flytekit.common.types import helpers as _helpers -from flytekit.common.types import primitives as _primitives -from flytekit.configuration import internal as _internal_config -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models - -input_types_supported = { - _primitives.Integer, - _primitives.Boolean, - _primitives.Float, - _primitives.String, - _primitives.Datetime, - _primitives.Timedelta, -} - - -class SdkGenericSparkTask(_base_tasks.SdkTask): - """ - This class includes the additional logic for building a task that executes as a Spark Job. - - """ - - def __init__( - self, - task_type, - discovery_version, - retries, - interruptible, - task_inputs, - deprecated, - discoverable, - timeout, - spark_type, - main_class, - main_application_file, - spark_conf, - hadoop_conf, - environment, - ): - """ - :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param bool interruptible: Whether or not task is interruptible - :param Text deprecated: - :param bool discoverable: - :param datetime.timedelta timeout: - :param Text spark_type: Type of Spark Job: Scala/Java - :param Text main_class: Main class to execute for Scala/Java jobs - :param Text main_application_file: Main application file - :param dict[Text,Text] spark_conf: - :param dict[Text,Text] hadoop_conf: - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. 
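Later in this class, _get_container_definition turns every declared interface input into a '--name {{.Inputs.name}}' argument pair that the platform substitutes at execution time. That template expansion is easy to get wrong, so as a sketch (build_task_args is illustrative, not flytekit API):

.. code-block:: python

    def build_task_args(input_names):
        # Each input becomes a '--<name>' flag plus a Flyte input template.
        args = []
        for name in input_names:
            args.append("--{}".format(name))
            args.append("{{{{.Inputs.{}}}}}".format(name))
        return args

    assert build_task_args(["date"]) == ["--date", "{{.Inputs.date}}"]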
- """ - - spark_job = _task_models.SparkJob( - spark_conf=spark_conf, - hadoop_conf=hadoop_conf, - spark_type=spark_type, - application_file=main_application_file, - main_class=main_class, - executor_path=_sys.executable, - ).to_flyte_idl() - - super(SdkGenericSparkTask, self).__init__( - task_type, - _task_models.TaskMetadata( - discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "spark", - ), - timeout, - _literal_models.RetryStrategy(retries), - interruptible, - discovery_version, - deprecated, - ), - _interface.TypedInterface({}, {}), - _MessageToDict(spark_job), - ) - - # Add Inputs - if task_inputs is not None: - task_inputs(self) - - # Container after the Inputs have been updated. - self._container = self._get_container_definition(environment=environment) - - def _validate_inputs(self, inputs): - """ - :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - for k, v in _six.iteritems(inputs): - sdk_type = _helpers.get_sdk_type_from_literal_type(v.type) - if sdk_type not in input_types_supported: - raise _user_exceptions.FlyteValidationException( - "Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format( - sdk_type - ) - ) - super(SdkGenericSparkTask, self)._validate_inputs(inputs) - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables - """ - self._validate_inputs(inputs) - self.interface.inputs.update(inputs) - - def _get_container_definition( - self, - environment=None, - ): - """ - :rtype: Container - """ - - args = [] - for k, v in _six.iteritems(self.interface.inputs): - args.append("--{}".format(k)) - args.append("{{{{.Inputs.{}}}}}".format(k)) - - return _task_models.Container( - image=_internal_config.IMAGE.get(), - command=[], - args=args, - resources=_task_models.Resources([], []), - env=environment, - config={}, - ) diff --git a/flytekit/common/tasks/hive_task.py b/flytekit/common/tasks/hive_task.py deleted file mode 100644 index 77ae3359a0..0000000000 --- a/flytekit/common/tasks/hive_task.py +++ /dev/null @@ -1,299 +0,0 @@ -import uuid as _uuid - -import six as _six -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common import nodes as _nodes -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions.user import FlyteTypeException as _FlyteTypeException -from flytekit.common.exceptions.user import FlyteValueException as _FlyteValueException -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import task as _base_task -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import qubole as _qubole -from flytekit.models.core import workflow 
as _workflow_model - -ALLOWED_TAGS_COUNT = int(6) -MAX_TAG_LENGTH = int(20) - - -class SdkHiveTask(_sdk_runnable.SdkRunnableTask): - """ - This class includes the additional logic for building a task that executes as a batch hive task. - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - cluster_label, - tags, - environment, - cache_serializable, - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. - :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param Text deprecated: - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param bool discoverable: - :param datetime.timedelta timeout: - :param Text cluster_label: - :param list[Text] tags: - :param dict[Text, Text] environment: - :param bool cache_serializable: - """ - self._task_function = task_function - super(SdkHiveTask, self).__init__( - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - {}, - ) - self._validate_task_parameters(cluster_label, tags) - self._cluster_label = cluster_label - self._tags = tags - - def _generate_plugin_objects(self, context, inputs_dict): - """ - Runs user code and and produces hive queries - :param flytekit.engines.common.EngineContext context: - :param dict[Text, T] inputs: - :rtype: list[_qubole.QuboleHiveJob] - """ - queries_from_task = super(SdkHiveTask, self)._execute_user_code(context, inputs_dict) or [] - if not isinstance(queries_from_task, list): - queries_from_task = [queries_from_task] - - self._validate_queries(queries_from_task) - plugin_objects = [] - - for q in queries_from_task: - hive_query = _qubole.HiveQuery( - query=q, - timeout_sec=self.metadata.timeout.seconds, - retry_count=self.metadata.retries.retries, - ) - - # TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are - # deprecated. This is only here for backwards compatibility - in addition to writing the query to the - # query field, we also construct a QueryCollection with only one query. This will ensure that the - # older plugin will continue to work. 
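For reference, the per-query fan-out that _generate_plugin_objects performs reduces to the following stand-alone sketch. The dataclass is a stand-in for flytekit.models.qubole.HiveQuery, not the real model; the timeout and retry values are lifted from the task's metadata just as in the code above:

    from dataclasses import dataclass
    from typing import List, Union

    @dataclass
    class HiveQuerySketch:  # stand-in for the Qubole HiveQuery model
        query: str
        timeout_sec: int
        retry_count: int

    def fan_out(queries: Union[str, List[str]], timeout_sec: int, retry_count: int) -> List[HiveQuerySketch]:
        # A task may return one query or a list of queries; normalize first.
        if not isinstance(queries, list):
            queries = [queries]
        # One job object per query, which later becomes one node per query.
        return [HiveQuerySketch(q, timeout_sec, retry_count) for q in queries]

    fan_out("SELECT 1", timeout_sec=0, retry_count=3)  # -> a single sketch object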
- query_collection = _qubole.HiveQueryCollection([hive_query]) - - plugin_objects.append( - _qubole.QuboleHiveJob( - hive_query, - self._cluster_label, - self._tags, - query_collection=query_collection, - ) - ) - - return plugin_objects - - @staticmethod - def _validate_task_parameters(cluster_label, tags): - if not (cluster_label is None or isinstance(cluster_label, (str, _six.text_type))): - raise _FlyteTypeException( - type(cluster_label), - {str, _six.text_type}, - additional_msg="cluster_label for a hive task must be in text format", - received_value=cluster_label, - ) - if tags is not None: - if not (isinstance(tags, list) and all(isinstance(tag, (str, _six.text_type)) for tag in tags)): - raise _FlyteTypeException( - type(tags), - [], - additional_msg="tags for a hive task must be in 'list of text' format", - received_value=tags, - ) - if len(tags) > ALLOWED_TAGS_COUNT: - raise _FlyteValueException( - len(tags), - "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), - ) - if not all(len(tag) for tag in tags): - raise _FlyteValueException( - tags, - "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), - ) - - @staticmethod - def _validate_queries(queries_from_task): - for query_from_task in queries_from_task or []: - if not isinstance(query_from_task, (str, _six.text_type)): - raise _FlyteTypeException( - type(query_from_task), - {str, _six.text_type}, - additional_msg="All queries returned from a Hive task must be in text format.", - received_value=query_from_task, - ) - - def _produce_dynamic_job_spec(self, context, inputs): - """ - Runs user code and and produces future task nodes to run sub-tasks. - :param context: - :param flytekit.models.literals.LiteralMap literal_map inputs: - :rtype: flytekit.models.dynamic_job.DynamicJobSpec - """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - outputs_dict = { - name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(self.interface.outputs) - } - - # Add outputs to inputs - inputs_dict.update(outputs_dict) - - nodes = [] - tasks = [] - # One node per query - generated_queries = self._generate_plugin_objects(context, inputs_dict) - - # Create output bindings always - this has to happen after user code has run - output_bindings = [ - _literal_models.Binding( - var=name, - binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), - ) - for name, b in _six.iteritems(outputs_dict) - ] - - i = 0 - for quboleHiveJob in generated_queries: - hive_job_node = _create_hive_job_node("HiveQuery_{}".format(i), quboleHiveJob.to_flyte_idl(), self.metadata) - nodes.append(hive_job_node) - tasks.append(hive_job_node.executable_sdk_object) - i += 1 - - dynamic_job_spec = _dynamic_job.DynamicJobSpec( - min_successes=len(nodes), - tasks=tasks, - nodes=nodes, - outputs=output_bindings, - subworkflows=[], - ) - - return dynamic_job_spec - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - Executes hive batch task's user code and produces futures file as well as all sub-task inputs.pb files. 
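One caveat worth flagging in the validation being removed here: the original check `if not all(len(tag) for tag in tags)` only rejects empty tags, so MAX_TAG_LENGTH was never actually enforced even though the error message claims a length rule. The intended rules, reconstructed from the constants and messages above as a corrected stand-alone sketch:

    ALLOWED_TAGS_COUNT = 6
    MAX_TAG_LENGTH = 20

    def validate_tags(tags):
        # Reconstructed intent of _validate_task_parameters: at most six tags,
        # each non-empty and at most twenty characters long.
        if len(tags) > ALLOWED_TAGS_COUNT:
            raise ValueError("number of tags must be less than {}".format(ALLOWED_TAGS_COUNT))
        if not all(0 < len(tag) <= MAX_TAG_LENGTH for tag in tags):
            raise ValueError("length of a tag must be less than {} chars".format(MAX_TAG_LENGTH))

    validate_tags(["etl", "ad-hoc"])  # passes; seven tags or a 21-char tag would raise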
- - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. - """ - spec = self._produce_dynamic_job_spec(context, inputs) - generated_files = {} - - # If no queries were produced, then the spec should not have any nodes, in which case we just produce an - # outputs file like any other single-step tasks. - if len(spec.nodes) == 0: - return { - _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs} - ) - } - else: - generated_files.update({_constants.FUTURES_FILE_NAME: spec}) - - return generated_files - - -def _create_hive_job_node(name, hive_job, metadata): - """ - :param Text name: - :param _qubole.QuboleHiveJob hive_job: Hive job spec - :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine - behavior such as whether or not outputs are discoverable, timeouts, and retries. - :rtype: _nodes.SdkNode: - """ - return _nodes.SdkNode( - id=_six.text_type(_uuid.uuid4()), - upstream_nodes=[], - bindings=[], - metadata=_workflow_model.NodeMetadata(name, metadata.timeout, _literal_models.RetryStrategy(0)), - sdk_task=SdkHiveJob(hive_job, metadata), - ) - - -class SdkHiveJob(_base_task.SdkTask): - """ - This class encapsulates the hive-job that is submitted to the Qubole Operator. - - """ - - def __init__( - self, - hive_job, - metadata, - ): - """ - :param _qubole.QuboleHiveJob hive_job: Hive job spec - :param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as - whether or not outputs are discoverable, timeouts, and retries. - """ - super(SdkHiveJob, self).__init__( - _constants.SdkTaskType.HIVE_JOB, - metadata, - # Individual hive tasks never take anything, or return anything. They just run a query that's already - # got the location set. 
- _interface_model.TypedInterface({}, {}), - _MessageToDict(hive_job), - ) diff --git a/flytekit/common/tasks/output.py b/flytekit/common/tasks/output.py deleted file mode 100644 index 8ca1b307cb..0000000000 --- a/flytekit/common/tasks/output.py +++ /dev/null @@ -1,46 +0,0 @@ -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.types import base_sdk_types as _base_sdk_types - - -class OutputReference(object): - def __init__(self, sdk_type): - """ - :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: - """ - self._raw_value = None - self._sdk_type = sdk_type - self._sdk_value = _base_sdk_types.Void() - - @property - def value(self): - """ - :rtype: T - """ - return self._raw_value - - @property - def sdk_value(self): - """ - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkValue - """ - return self._sdk_value - - @property - def sdk_type(self): - """ - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - return self._sdk_type - - @_exception_scopes.system_entry_point - def set(self, value): - """ - This should be called to set the value for output. The SDK will apply the appropriate type and value checking. - It will raise an exception if necessary. - :param T value: - :raises: flytekit.common.exceptions.user.FlyteValueException - """ - - sdk_value = self._sdk_type.from_python_std(value) - self._raw_value = value - self._sdk_value = sdk_value diff --git a/flytekit/common/tasks/presto_task.py b/flytekit/common/tasks/presto_task.py deleted file mode 100644 index c8f7d300a5..0000000000 --- a/flytekit/common/tasks/presto_task.py +++ /dev/null @@ -1,180 +0,0 @@ -import datetime as _datetime - -import six as _six -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit import __version__ -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import task as _base_task -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literals -from flytekit.models import presto as _presto_models -from flytekit.models import task as _task_model -from flytekit.models import types as _types - - -class SdkPrestoTask(_base_task.SdkTask): - """ - This class includes the logic for building a task that executes as a Presto task. 
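The OutputReference removed in output.py above is essentially a write-once, type-checked value holder. Outside of flytekit the pattern looks like this minimal sketch, with an ordinary callable standing in for the FlyteSdkType's from_python_std:

    class OutputRef:
        def __init__(self, caster):
            self._caster = caster      # plays the role of sdk_type.from_python_std
            self._raw_value = None
            self._sdk_value = None

        def set(self, value):
            # Convert (and therefore type-check) before storing anything, so a
            # bad value raises without leaving the reference half-updated.
            sdk_value = self._caster(value)
            self._raw_value = value
            self._sdk_value = sdk_value

    ref = OutputRef(int)
    ref.set("42")                      # stores raw "42" and converted 42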
- """ - - def __init__( - self, - statement, - output_schema, - routing_group=None, - catalog=None, - schema=None, - task_inputs=None, - interruptible=False, - discoverable=False, - discovery_version=None, - retries=1, - timeout=None, - deprecated=None, - cache_serializable=False, - ): - """ - :param Text statement: Presto query specification - :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto - :param Text routing_group: The routing group that a Presto query should be sent to for the given environment - :param Text catalog: The catalog to set for the given Presto query - :param Text schema: The schema to set for the given Presto query - :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the Presto task - :param bool discoverable: - :param Text discovery_version: String describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param datetime.timedelta timeout: - :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will - receive deprecation warnings. - :param bool cache_serializable: - """ - - # Set as class fields which are used down below to configure implicit - # parameters - self._routing_group = routing_group or "" - self._catalog = catalog or "" - self._schema = schema or "" - - metadata = _task_model.TaskMetadata( - discoverable, - # This needs to have the proper version reflected in it - _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"), - timeout or _datetime.timedelta(seconds=0), - _literals.RetryStrategy(retries), - interruptible, - discovery_version, - deprecated, - cache_serializable, - ) - - presto_query = _presto_models.PrestoQuery( - routing_group=routing_group or "", - catalog=catalog or "", - schema=schema or "", - statement=statement, - ) - - # Here we set the routing_group, catalog, and schema as implicit - # parameters for caching purposes - i = _interface.TypedInterface( - { - "__implicit_routing_group": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The routing group set as an implicit input", - ), - "__implicit_catalog": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The catalog set as an implicit input", - ), - "__implicit_schema": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The schema set as an implicit input", - ), - }, - { - # Set the schema for the Presto query as an output - "results": _interface_model.Variable( - type=_types.LiteralType(schema=output_schema.schema_type), - description="The schema for the Presto query", - ) - }, - ) - - super(SdkPrestoTask, self).__init__( - _constants.SdkTaskType.PRESTO_TASK, - metadata, - i, - _MessageToDict(presto_query.to_flyte_idl()), - ) - - # Set user provided inputs - task_inputs(self) - - def _add_implicit_inputs(self, inputs): - """ - :param dict[Text,Any] inputs: - :param inputs: - :return: - """ - inputs["__implicit_routing_group"] = self.routing_group - inputs["__implicit_catalog"] = self.catalog - inputs["__implicit_schema"] = self.schema - return inputs - - # Override method in order to set the implicit inputs - def __call__(self, *args, **kwargs): - kwargs = self._add_implicit_inputs(kwargs) - - return super(SdkPrestoTask, self).__call__(*args, **kwargs) - - # Override method in order to set 
the implicit inputs - def _python_std_input_map_to_literal_map(self, inputs): - """ - :param dict[Text,Any] inputs: A dictionary of Python standard inputs that will be type-checked and compiled - to a LiteralMap - :rtype: flytekit.models.literals.LiteralMap - """ - inputs = self._add_implicit_inputs(inputs) - return _type_helpers.pack_python_std_map_to_literal_map( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables - """ - self._validate_inputs(inputs) - self.interface.inputs.update(inputs) - - @property - def routing_group(self): - """ - The routing group that a Presto query should be sent to for the given environment - :rtype: Text - """ - return self._routing_group - - @property - def catalog(self): - """ - The catalog to set for the given Presto query - :rtype: Text - """ - return self._catalog - - @property - def schema(self): - """ - The schema to set for the given Presto query - :rtype: Text - """ - return self._schema diff --git a/flytekit/common/tasks/raw_container.py b/flytekit/common/tasks/raw_container.py deleted file mode 100644 index 5168e744f3..0000000000 --- a/flytekit/common/tasks/raw_container.py +++ /dev/null @@ -1,237 +0,0 @@ -import datetime as _datetime -from typing import Dict, List - -import flytekit -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import task as _base_task -from flytekit.common.types.base_sdk_types import FlyteSdkType -from flytekit.configuration import resources as _resource_config -from flytekit.models import literals as _literals -from flytekit.models import task as _task_models -from flytekit.models.interface import Variable - - -def types_to_variable(t: Dict[str, FlyteSdkType]) -> Dict[str, Variable]: - var = {} - if t: - for k, v in t.items(): - var[k] = Variable(v.to_flyte_literal_type(), "") - return var - - -def _get_container_definition( - image: str, - command: List[str], - args: List[str], - data_loading_config: _task_models.DataLoadingConfig, - storage_request: str = None, - ephemeral_storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - ephemeral_storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - environment: Dict[str, str] = None, -) -> _task_models.Container: - storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() - storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() - ephemeral_storage_limit = ephemeral_storage_limit or _resource_config.DEFAULT_EPHEMERAL_STORAGE_LIMIT.get() - ephemeral_storage_request = ephemeral_storage_request or _resource_config.DEFAULT_EPHEMERAL_STORAGE_REQUEST.get() - cpu_limit = cpu_limit or _resource_config.DEFAULT_CPU_LIMIT.get() - cpu_request = cpu_request or _resource_config.DEFAULT_CPU_REQUEST.get() - gpu_limit = gpu_limit or _resource_config.DEFAULT_GPU_LIMIT.get() - gpu_request = 
gpu_request or _resource_config.DEFAULT_GPU_REQUEST.get() - memory_limit = memory_limit or _resource_config.DEFAULT_MEMORY_LIMIT.get() - memory_request = memory_request or _resource_config.DEFAULT_MEMORY_REQUEST.get() - - requests = [] - if storage_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) - ) - if ephemeral_storage_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request - ) - ) - if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) - if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) - if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) - - limits = [] - if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) - if ephemeral_storage_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit - ) - ) - if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) - if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) - if memory_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) - - if environment is None: - environment = {} - - return _task_models.Container( - image=image, - command=command, - args=args, - resources=_task_models.Resources(limits=limits, requests=requests), - env=environment, - config={}, - data_loading_config=data_loading_config, - ) - - -class SdkRawContainerTask(_base_task.SdkTask): - """ - Use this task when you want to run an arbitrary container as a task (e.g. external tools, binaries compiled - separately as a container completely separate from the container where your Flyte workflow is defined. 
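The request/limit assembly above follows a simple fall-back-then-collect shape. Isolated from the flytekit models (the names here are illustrative only), it is just:

    def collect_resources(defaults, overrides):
        # Every unset value falls back to its configured default, and only
        # names that end up with a non-empty value become entries at all.
        merged = {name: overrides.get(name) or default for name, default in defaults.items()}
        return [(name, value) for name, value in merged.items() if value]

    collect_resources({"cpu": "500m", "memory": "1Gi", "gpu": ""}, {"memory": "2Gi"})
    # -> [('cpu', '500m'), ('memory', '2Gi')]   (gpu dropped: no value anywhere)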
- """ - - METADATA_FORMAT_JSON = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_JSON - METADATA_FORMAT_YAML = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_YAML - METADATA_FORMAT_PROTO = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_PROTO - - def __init__( - self, - inputs: Dict[str, FlyteSdkType], - image: str, - outputs: Dict[str, FlyteSdkType] = None, - input_data_dir: str = None, - output_data_dir: str = None, - metadata_format: int = METADATA_FORMAT_JSON, - io_strategy: _task_models.IOStrategy = None, - command: List[str] = None, - args: List[str] = None, - storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - environment: Dict[str, str] = None, - interruptible: bool = False, - discoverable: bool = False, - discovery_version: str = None, - retries: int = 1, - timeout: _datetime.timedelta = None, - cache_serializable: bool = False, - ): - """ - :param inputs: - :param outputs: - :param image: - :param command: - :param args: - :param storage_request: - :param cpu_request: - :param gpu_request: - :param memory_request: - :param storage_limit: - :param cpu_limit: - :param gpu_limit: - :param memory_limit: - :param environment: - :param interruptible: - :param discoverable: - :param discovery_version: - :param retries: - :param timeout: - :param cache_serializable: - :param input_data_dir: This is the directory where data will be downloaded to - :param output_data_dir: This is the directory where data will be uploaded from - :param metadata_format: Format in which the metadata will be available for the script - """ - - # Set as class fields which are used down below to configure implicit - # parameters - self._data_loading_config = _task_models.DataLoadingConfig( - input_path=input_data_dir, - output_path=output_data_dir, - format=metadata_format, - enabled=True, - io_strategy=io_strategy, - ) - - metadata = _task_models.TaskMetadata( - discoverable, - # This needs to have the proper version reflected in it - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - flytekit.__version__, - "python", - ), - timeout or _datetime.timedelta(seconds=0), - _literals.RetryStrategy(retries), - interruptible, - discovery_version, - None, - cache_serializable, - ) - - # The interface is defined using the inputs and outputs - i = _interface.TypedInterface(inputs=types_to_variable(inputs), outputs=types_to_variable(outputs)) - - # This sets the base SDKTask with container etc - super(SdkRawContainerTask, self).__init__( - _constants.SdkTaskType.RAW_CONTAINER_TASK, - metadata, - i, - None, - container=_get_container_definition( - image=image, - args=args, - command=command, - data_loading_config=self._data_loading_config, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - environment=environment, - ), - ) - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs: Dict[str, Variable]): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. 
- :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables - """ - self._validate_inputs(inputs) - self.interface.inputs.update(inputs) - - @_exception_scopes.system_entry_point - def add_outputs(self, outputs: Dict[str, Variable]): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] outputs: names and variables - """ - self._validate_outputs(outputs) - self.interface.outputs.update(outputs) diff --git a/flytekit/common/tasks/sdk_dynamic.py b/flytekit/common/tasks/sdk_dynamic.py deleted file mode 100644 index b3bc923601..0000000000 --- a/flytekit/common/tasks/sdk_dynamic.py +++ /dev/null @@ -1,372 +0,0 @@ -import itertools as _itertools -import math -import os as _os - -import six as _six - -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common import launch_plan as _launch_plan -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import workflow as _workflow -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.mixins import registerable as _registerable -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import task as _task -from flytekit.common.types import helpers as _type_helpers -from flytekit.common.utils import _dnsify -from flytekit.configuration import internal as _internal_config -from flytekit.models import array_job as _array_job -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literal_models - - -class PromiseOutputReference(_task_output.OutputReference): - @property - def raw_value(self): - """ - :rtype: T - """ - return self._raw_value - - @_exception_scopes.system_entry_point - def set(self, value): - """ - This should be called to set the value for output. The SDK will apply the appropriate type and value checking. - It will raise an exception if necessary. - :param T value: - :raises: flytekit.common.exceptions.user.FlyteValueException - """ - - self._raw_value = value - - -def _append_node(generated_files, node, nodes, sub_task_node): - nodes.append(node) - for k, node_output in _six.iteritems(sub_task_node.outputs): - if not node_output.sdk_node.id: - node_output.sdk_node.assign_id_and_return(node.id) - - # Upload inputs to working directory under /array_job.input_ref/inputs.pb - input_path = _os.path.join(node.id, _constants.INPUT_FILE_NAME) - generated_files[input_path] = _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in sub_task_node.inputs} - ) - - -class SdkDynamicTaskMixin(object): - - """ - This mixin implements logic for building a task that executes - parent-child tasks in Python code. - - """ - - def __init__(self, allowed_failure_ratio, max_concurrency): - """ - :param float allowed_failure_ratio: - :param int max_concurrency: - """ - - # These will only appear in the generated futures - self._allowed_failure_ratio = allowed_failure_ratio - self._max_concurrency = max_concurrency - - def _create_array_job(self, inputs_prefix): - """ - Creates an array job for the passed sdk_task. 
- :param str inputs_prefix: - :rtype: _array_job.ArrayJob - """ - return _array_job.ArrayJob( - parallelism=self._max_concurrency if self._max_concurrency else 0, - size=1, - min_successes=1, - ) - - @staticmethod - def _can_run_as_array(task_type): - """ - Checks if a task can be grouped to run as an array job. - :param Text task_type: - :rtype: bool - """ - return task_type == _constants.SdkTaskType.PYTHON_TASK - - @staticmethod - def _add_upstream_entities(executable_sdk_object, sub_workflows, tasks): - upstream_entities = [] - if isinstance(executable_sdk_object, _workflow.SdkWorkflow): - upstream_entities = [n.executable_sdk_object for n in executable_sdk_object.nodes] - - for upstream_entity in upstream_entities: - # If the upstream entity is either a Workflow or a Task, yield them in the - # dynamic job spec. Otherwise (e.g. a LaunchPlan), we will assume it already - # is registered (can't be dynamically created). This will cause a runtime error - # if it's not already registered with the control plane. - if isinstance(upstream_entity, _workflow.SdkWorkflow): - sub_workflows.add(upstream_entity) - # Recursively discover all statically defined dependencies - SdkDynamicTask._add_upstream_entities(upstream_entity, sub_workflows, tasks) - elif isinstance(upstream_entity, _task.SdkTask): - tasks.add(upstream_entity) - - def _produce_dynamic_job_spec(self, context, inputs): - """ - Runs user code and and produces future task nodes to run sub-tasks. - :param context: - :param flytekit.models.literals.LiteralMap literal_map inputs: - :rtype: (_dynamic_job.DynamicJobSpec, dict[Text, flytekit.models.common.FlyteIdlEntity]) - """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - outputs_dict = { - name: PromiseOutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(self.interface.outputs) - } - - # Because users declare both inputs and outputs in their functions signatures, merge them together - # before calling user code - inputs_dict.update(outputs_dict) - yielded_sub_tasks = [sub_task for sub_task in self._execute_user_code(context, inputs_dict) or []] - - upstream_nodes = list() - output_bindings = [ - _literal_models.Binding( - var=name, - binding=_interface.BindingData.from_python_std( - b.sdk_type.to_flyte_literal_type(), - b.raw_value, - upstream_nodes=upstream_nodes, - ), - ) - for name, b in _six.iteritems(outputs_dict) - ] - upstream_nodes = set(upstream_nodes) - - generated_files = {} - # Keeping future-tasks in original order. We don't use upstream_nodes exclusively because the parent task can - # yield sub-tasks that it never uses to produce final outputs but they need to execute nevertheless. - array_job_index = {} - tasks = set() - nodes = [] - sub_workflows = set() - visited_nodes = set() - generated_ids = {} - effective_failure_ratio = self._allowed_failure_ratio or 0.0 - - # TODO: This function needs to be cleaned up. - # The reason we chain these two together is because we allow users to not have to explicitly "yield" the - # node. As long as the subtask/lp/subwf has an output that's referenced, it'll get picked up. 
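The chain at the heart of the loop that follows, reduced to its essentials: yielded nodes are walked first so their order is preserved, upstream-only nodes come after, and every node is visited exactly once:

    import itertools

    def unique_in_order(yielded, upstream):
        seen = set()
        for node in itertools.chain(yielded, upstream):
            if node in seen:
                continue
            seen.add(node)
            yield node

    list(unique_in_order(["a", "b", "a"], ["c", "b"]))  # -> ['a', 'b', 'c']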
- for sub_task_node in _itertools.chain(yielded_sub_tasks, upstream_nodes): - if sub_task_node in visited_nodes: - continue - visited_nodes.add(sub_task_node) - executable = sub_task_node.executable_sdk_object - - # If the executable object that we're dealing with is registerable (ie, SdkRunnableLaunchPlan, SdkWorkflow - # SdkTask, or SdkRunnableTask), then it should have the ability to give itself a name. After assigning - # itself the name, also make sure the id is properly set according to current config values. - if isinstance(executable, _registerable.TrackableEntity) and not executable.has_valid_name: - executable.auto_assign_name() - executable._id = _identifier.Identifier( - executable.resource_type, - _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(), - _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(), - executable.platform_valid_name, - _internal_config.TASK_VERSION.get() or _internal_config.VERSION.get(), - ) - - # Generate an id that's unique in the document (if the same task is used multiple times with - # different resources, executable_sdk_object.id will be the same but generated node_ids should not - # be. - safe_task_id = _six.text_type(sub_task_node.executable_sdk_object.id) - if safe_task_id in generated_ids: - new_count = generated_ids[safe_task_id] = generated_ids[safe_task_id] + 1 - else: - new_count = generated_ids[safe_task_id] = 0 - unique_node_id = _dnsify("{}-{}".format(safe_task_id, new_count)) - - # Handling case where the yielded node is launch plan - if isinstance(sub_task_node.executable_sdk_object, _launch_plan.SdkLaunchPlan): - node = sub_task_node.assign_id_and_return(unique_node_id) - _append_node(generated_files, node, nodes, sub_task_node) - # Handling case where the yielded node is launching a sub-workflow - elif isinstance(sub_task_node.executable_sdk_object, _workflow.SdkWorkflow): - node = sub_task_node.assign_id_and_return(unique_node_id) - _append_node(generated_files, node, nodes, sub_task_node) - # Add the workflow itself to the yielded sub-workflows - sub_workflows.add(sub_task_node.executable_sdk_object) - # Recursively discover statically defined upstream entities (tasks, wfs) - SdkDynamicTask._add_upstream_entities(sub_task_node.executable_sdk_object, sub_workflows, tasks) - # Handling tasks - else: - # If the task can run as an array job, group its instances together. Otherwise, keep each - # invocation as a separate node. 
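The document-unique node ids generated above boil down to a per-task counter plus DNS-sanitization. A sketch of that scheme, with a crude stand-in for _dnsify (the real helper does more normalization):

    def unique_node_id(task_id: str, counts: dict) -> str:
        counts[task_id] = counts.get(task_id, -1) + 1    # 0 on first use, then 1, 2, ...
        node_id = "{}-{}".format(task_id, counts[task_id])
        return node_id.lower().replace("_", "-")         # crude stand-in for _dnsify

    counts = {}
    unique_node_id("my_task", counts)  # -> 'my-task-0'
    unique_node_id("my_task", counts)  # -> 'my-task-1'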
- if SdkDynamicTask._can_run_as_array(sub_task_node.executable_sdk_object.type): - if sub_task_node.executable_sdk_object in array_job_index: - array_job, node = array_job_index[sub_task_node.executable_sdk_object] - array_job.size += 1 - array_job.min_successes = int(math.ceil((1 - effective_failure_ratio) * array_job.size)) - else: - array_job = self._create_array_job(inputs_prefix=unique_node_id) - node = sub_task_node.assign_id_and_return(unique_node_id) - array_job_index[sub_task_node.executable_sdk_object] = ( - array_job, - node, - ) - - node_index = _six.text_type(array_job.size - 1) - for k, node_output in _six.iteritems(sub_task_node.outputs): - if not node_output.sdk_node.id: - node_output.sdk_node.assign_id_and_return(node.id) - node_output.var = "[{}].{}".format(node_index, node_output.var) - - # Upload inputs to working directory under /array_job.input_ref//inputs.pb - input_path = _os.path.join(node.id, node_index, _constants.INPUT_FILE_NAME) - generated_files[input_path] = _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in sub_task_node.inputs} - ) - else: - node = sub_task_node.assign_id_and_return(unique_node_id) - tasks.add(sub_task_node.executable_sdk_object) - _append_node(generated_files, node, nodes, sub_task_node) - - # assign custom field to the ArrayJob properties computed. - for task, (array_job, _) in _six.iteritems(array_job_index): - # TODO: Reconstruct task template object instead of modifying an existing one? - tasks.add( - task.assign_custom_and_return(array_job.to_dict()).assign_type_and_return( - _constants.SdkTaskType.CONTAINER_ARRAY_TASK - ) - ) - - # min_successes is absolute, it's computed as the reverse of allowed_failure_ratio and multiplied by the - # total length of tasks to get an absolute count. - nodes.extend([array_job_node for (_, array_job_node) in array_job_index.values()]) - dynamic_job_spec = _dynamic_job.DynamicJobSpec( - min_successes=len(nodes), - tasks=list(tasks), - nodes=nodes, - outputs=output_bindings, - subworkflows=list(sub_workflows), - ) - - return dynamic_job_spec, generated_files - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - Executes batch task's user code and produces futures file as well as all sub-task inputs.pb files. - - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. - """ - spec, generated_files = self._produce_dynamic_job_spec(context, inputs) - - # If no sub-tasks are requested to run, just produce an outputs file like any other single-step tasks. 
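The min_successes arithmetic above in worked form: every time the array grows, the allowed failure ratio is converted into an absolute success count, rounded up:

    import math

    def min_successes(size: int, allowed_failure_ratio: float = 0.0) -> int:
        return int(math.ceil((1 - allowed_failure_ratio) * size))

    assert min_successes(10, 0.2) == 8  # tolerate 2 of 10 sub-tasks failing
    assert min_successes(3) == 3        # default: every sub-task must succeed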
- if len(spec.nodes) == 0: - return { - _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs} - ) - } - else: - generated_files.update({_constants.FUTURES_FILE_NAME: spec}) - - return generated_files - - -class SdkDynamicTask( - SdkDynamicTaskMixin, - _sdk_runnable.SdkRunnableTask, - metaclass=_sdk_bases.ExtendedSdkType, -): - - """ - This class includes the additional logic for building a task that executes - parent-child tasks in Python code. - - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - allowed_failure_ratio, - max_concurrency, - environment, - cache_serializable, - custom, - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. - :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param bool interruptible: Whether or not task is interruptible - :param Text deprecated: - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param bool discoverable: - :param datetime.timedelta timeout: - :param float allowed_failure_ratio: - :param int max_concurrency: - :param dict[Text, Text] environment: - :param bool cache_serializable: - :param dict[Text, T] custom: - """ - _sdk_runnable.SdkRunnableTask.__init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - custom, - ) - - SdkDynamicTaskMixin.__init__(self, allowed_failure_ratio, max_concurrency) diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py deleted file mode 100644 index 39788437dd..0000000000 --- a/flytekit/common/tasks/sdk_runnable.py +++ /dev/null @@ -1,750 +0,0 @@ -from __future__ import annotations - -import copy as _copy -import enum -import logging as _logging -import os -import pathlib -import typing -from dataclasses import dataclass -from datetime import datetime -from inspect import getfullargspec as _getargspec - -import six as _six - -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import utils as _common_utils -from flytekit.common.core.identifier import WorkflowExecutionIdentifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import task as _base_task -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import internal as _internal_config -from flytekit.configuration import resources as _resource_config -from flytekit.configuration import sdk as _sdk_config -from flytekit.configuration import secrets -from flytekit.engines 
import loader as _engine_loader -from flytekit.interfaces.stats import taggable -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models - - -class SecretsManager(object): - """ - This provides a secrets resolution logic at runtime. - The resolution order is - - Try env var first. The env var should have the configuration.SECRETS_ENV_PREFIX. The env var will be all upper - cased - - If not then try the file where the name matches lower case - ``configuration.SECRETS_DEFAULT_DIR//configuration.SECRETS_FILE_PREFIX`` - - All configuration values can always be overridden by injecting an environment variable - """ - - def __init__(self): - self._base_dir = str(secrets.SECRETS_DEFAULT_DIR.get()).strip() - self._file_prefix = str(secrets.SECRETS_FILE_PREFIX.get()).strip() - self._env_prefix = str(secrets.SECRETS_ENV_PREFIX.get()).strip() - - def get(self, group: str, key: str) -> str: - """ - Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError - """ - self.check_group_key(group, key) - env_var = self.get_secrets_env_var(group, key) - fpath = self.get_secrets_file(group, key) - v = os.environ.get(env_var) - if v is not None: - return v - if os.path.exists(fpath): - with open(fpath, "r") as f: - return f.read().strip() - raise ValueError( - f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" - ) - - def get_secrets_env_var(self, group: str, key: str) -> str: - """ - Returns a string that matches the ENV Variable to look for the secrets - """ - self.check_group_key(group, key) - return f"{self._env_prefix}{group.upper()}_{key.upper()}" - - def get_secrets_file(self, group: str, key: str) -> str: - """ - Returns a path that matches the file to look for the secrets - """ - self.check_group_key(group, key) - return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}") - - @staticmethod - def check_group_key(group: str, key: str): - if group is None or group == "": - raise ValueError("secrets group is a mandatory field.") - if key is None or key == "": - raise ValueError("secrets key is a mandatory field.") - - -# TODO: Clean up working dir name -class ExecutionParameters(object): - """ - This is a run-time user-centric context object that is accessible to every @task method. It can be accessed using - - .. code-block:: python - - flytekit.current_context() - - This object provides the following - * a statsd handler - * a logging handler - * the execution ID as an :py:class:`flytekit.models.core.identifier.WorkflowExecutionIdentifier` object - * a working directory for the user to write arbitrary files to - - Please do not confuse this object with the :py:class:`flytekit.FlyteContext` object. 
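Stripped of configuration plumbing, the resolution order implemented by SecretsManager.get above is the following. The concrete prefix and directory defaults shown are assumptions for illustration; the real values come from the secrets configuration section and can be overridden by environment variables:

    import os

    def resolve_secret(group: str, key: str,
                       base_dir: str = "/etc/secrets",      # assumed default
                       file_prefix: str = "",               # assumed default
                       env_prefix: str = "_FSEC_") -> str:  # assumed default
        # 1. Environment variable, upper-cased: <prefix><GROUP>_<KEY>
        env_var = f"{env_prefix}{group.upper()}_{key.upper()}"
        v = os.environ.get(env_var)
        if v is not None:
            return v
        # 2. File, lower-cased: <base_dir>/<group>/<file_prefix><key>
        fpath = os.path.join(base_dir, group.lower(), f"{file_prefix}{key.lower()}")
        if os.path.exists(fpath):
            with open(fpath, "r") as f:
                return f.read().strip()
        raise ValueError(f"Unable to find secret for key {key} in group {group}")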
- """ - - @dataclass(init=False) - class Builder(object): - stats: taggable.TaggableStats - execution_date: datetime - logging: _logging - execution_id: str - attrs: typing.Dict[str, typing.Any] - working_dir: typing.Union[os.PathLike, _common_utils.AutoDeletingTempDir] - - def __init__(self, current: typing.Optional[ExecutionParameters] = None): - self.stats = current.stats if current else None - self.execution_date = current.execution_date if current else None - self.working_dir = current.working_directory if current else None - self.execution_id = current.execution_id if current else None - self.logging = current.logging if current else None - self.attrs = current._attrs if current else {} - - def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: - self.attrs[key] = v - return self - - def build(self) -> ExecutionParameters: - if not isinstance(self.working_dir, _common_utils.AutoDeletingTempDir): - pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True) - return ExecutionParameters( - execution_date=self.execution_date, - stats=self.stats, - tmp_dir=self.working_dir, - execution_id=self.execution_id, - logging=self.logging, - **self.attrs, - ) - - @staticmethod - def new_builder(current: ExecutionParameters = None) -> Builder: - return ExecutionParameters.Builder(current=current) - - def builder(self) -> Builder: - return ExecutionParameters.Builder(current=self) - - def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, **kwargs): - """ - Args: - execution_date: Date when the execution is running - tmp_dir: temporary directory for the execution - stats: handle to emit stats - execution_id: Identifier for the xecution - logging: handle to logging - """ - self._stats = stats - self._execution_date = execution_date - self._working_directory = tmp_dir - self._execution_id = execution_id - self._logging = logging - # AutoDeletingTempDir's should be used with a with block, which creates upon entry - self._attrs = kwargs - # It is safe to recreate the Secrets Manager - self._secrets_manager = SecretsManager() - - @property - def stats(self) -> taggable.TaggableStats: - """ - A handle to a special statsd object that provides usefully tagged stats. - TODO: Usage examples and better comments - """ - return self._stats - - @property - def logging(self) -> _logging: - """ - A handle to a useful logging object. - TODO: Usage examples - """ - return self._logging - - @property - def working_directory(self) -> _common_utils.AutoDeletingTempDir: - """ - A handle to a special working directory for easily producing temporary files. - - TODO: Usage examples - TODO: This does not always return a AutoDeletingTempDir - """ - return self._working_directory - - @property - def execution_date(self) -> datetime: - """ - This is a datetime representing the time at which a workflow was started. This is consistent across all tasks - executed in a workflow or sub-workflow. - - .. note:: - - Do NOT use this execution_date to drive any production logic. It might be useful as a tag for data to help - in debugging. - """ - return self._execution_date - - @property - def execution_id(self) -> str: - """ - This is the identifier of the workflow execution within the underlying engine. It will be consistent across all - task executions in a workflow or sub-workflow execution. - - .. note:: - - Do NOT use this execution_id to drive any production logic. This execution ID should only be used as a tag - on output data to link back to the workflow run that created it. 
- """ - return self._execution_id - - @property - def secrets(self) -> SecretsManager: - return self._secrets_manager - - def __getattr__(self, attr_name: str) -> typing.Any: - """ - This houses certain task specific context. For example in Spark, it houses the SparkSession, etc - """ - attr_name = attr_name.upper() - if self._attrs and attr_name in self._attrs: - return self._attrs[attr_name] - raise AssertionError(f"{attr_name} not available as a parameter in Flyte context - are you in right task-type?") - - def has_attr(self, attr_name: str) -> bool: - attr_name = attr_name.upper() - if self._attrs and attr_name in self._attrs: - return True - return False - - def get(self, key: str) -> typing.Any: - """ - Returns task specific context if present else raise an error. The returned context will match the key - """ - return self.__getattr__(attr_name=key) - - -class SdkRunnableContainer(_task_models.Container, metaclass=_sdk_bases.ExtendedSdkType): - """ - This is not necessarily a local-only Container object. So long as configuration is present, you can use this object - """ - - def __init__( - self, - command, - args, - resources, - env, - config, - ): - super(SdkRunnableContainer, self).__init__("", command, args, resources, env or {}, config) - - @property - def args(self): - """ - :rtype: list[Text] - """ - return _sdk_config.SDK_PYTHON_VENV.get() + self._args - - @property - def image(self): - """ - :rtype: Text - """ - return _internal_config.IMAGE.get() - - @property - def env(self): - """ - :rtype: dict[Text,Text] - """ - env = super(SdkRunnableContainer, self).env.copy() - env.update( - { - _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(), - _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(), - # TODO: Phase out the below. 
Propeller will set these and these are not SDK specific - _internal_config.PROJECT.env_var: _internal_config.PROJECT.get(), - _internal_config.DOMAIN.env_var: _internal_config.DOMAIN.get(), - _internal_config.NAME.env_var: _internal_config.NAME.get(), - _internal_config.VERSION.env_var: _internal_config.VERSION.get(), - } - ) - return env - - @classmethod - def get_resources( - cls, - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - ): - """ - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - """ - requests = [] - if storage_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) - ) - if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) - if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) - if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) - - limits = [] - if storage_limit: - limits.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit) - ) - if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) - if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) - if memory_limit: - limits.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit) - ) - - return _task_models.Resources(limits=limits, requests=requests) - - -class SdkRunnableTaskStyle(enum.Enum): - V0 = 0 - V1 = 1 - - -class SdkRunnableTask(_base_task.SdkTask, metaclass=_sdk_bases.ExtendedSdkType): - """ - This class includes the additional logic for building a task that executes in Python code. It has even more - validation checks to ensure proper behavior than it's superclasses. - - Since an SdkRunnableTask is assumed to run by hooking into Python code, we will provide additional shortcuts and - methods on this object. - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - custom, - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. 
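For the get_resources helper above, only the values actually provided become request or limit entries; everything else is simply omitted from the resulting proto. A typical call against this 0.x class (values illustrative):

    from flytekit.common.tasks.sdk_runnable import SdkRunnableContainer

    resources = SdkRunnableContainer.get_resources(
        cpu_request="500m",
        memory_request="1Gi",
        memory_limit="2Gi",
    )
    # requests -> [CPU "500m", MEMORY "1Gi"]; limits -> [MEMORY "2Gi"]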
- :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param bool interruptible: Specify whether task is interruptible - :param Text deprecated: - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param bool discoverable: - :param datetime.timedelta timeout: - :param dict[Text, Text] environment: - :param bool cache_serializable: - :param dict[Text, T] custom: - """ - # Circular dependency - from flytekit import __version__ - - self._task_function = task_function - super(SdkRunnableTask, self).__init__( - task_type, - _task_models.TaskMetadata( - discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "python", - ), - timeout, - _literal_models.RetryStrategy(retries), - interruptible, - discovery_version, - deprecated, - cache_serializable, - ), - # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly. - _interface.TypedInterface({}, {}), - custom, - container=self._get_container_definition( - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - environment=environment, - ), - ) - self.id._name = "{}.{}".format(self.task_module, self.task_function_name) - self._has_fast_registered = False - - # TODO: Remove this in the future, I don't think we'll be using this. - self._task_style = SdkRunnableTaskStyle.V0 - - _banned_inputs = {} - _banned_outputs = {} - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables - """ - self._validate_inputs(inputs) - self.interface.inputs.update(inputs) - - @classmethod - def promote_from_model(cls, base_model): - # TODO: If the task exists in this container, we should be able to retrieve it. - raise _user_exceptions.FlyteAssertion("Cannot promote a base object to a runnable task.") - - @property - def task_style(self): - return self._task_style - - @property - def task_function(self): - return self._task_function - - @property - def task_function_name(self): - """ - :rtype: Text - """ - return self.task_function.__name__ - - @property - def task_module(self): - """ - :rtype: Text - """ - return self._task_function.__module__ - - def validate(self): - super(SdkRunnableTask, self).validate() - missing_args = self._missing_mapped_inputs_outputs() - if len(missing_args) > 0: - raise _user_exceptions.FlyteAssertion( - "The task {} is invalid because not all inputs and outputs in the " - "task function definition were specified in @outputs and @inputs. " - "We are missing definitions for {}.".format(self, missing_args) - ) - - @_exception_scopes.system_entry_point - def unit_test(self, **input_map): - """ - :param dict[Text, T] input_map: Python Std input from users. We will cast these to the appropriate Flyte - literals. 
- :returns: Depends on the behavior of the specific task in the unit engine. - """ - return ( - _engine_loader.get_engine("unit") - .get_task(self) - .execute( - _type_helpers.pack_python_std_map_to_literal_map( - input_map, - { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }, - ) - ) - ) - - @_exception_scopes.system_entry_point - def local_execute(self, **input_map): - """ - :param dict[Text, T] input_map: Python Std input from users. We will cast these to the appropriate Flyte - literals. - :rtype: dict[Text, T] - :returns: The output produced by this task in Python standard format. - """ - return ( - _engine_loader.get_engine("local") - .get_task(self) - .execute( - _type_helpers.pack_python_std_map_to_literal_map( - input_map, - { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }, - ) - ) - ) - - def _execute_user_code(self, context, inputs): - """ - :param flytekit.engines.common.EngineContext context: - :param dict[Text, T] inputs: This variable is a bit of a misnomer, since it's both inputs and outputs. The - dictionary passed here will be passed to the user-defined function, and will have values that are a - variety of types. The T's here are Python std values for inputs. If there isn't a native Python type for - something (like Schema or Blob), they are the Flyte classes. For outputs they are OutputReferences. - (Note that these are not the same OutputReferences as in BindingData's) - :rtype: Any: the returned object from user code. - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. - """ - if self.task_style == SdkRunnableTaskStyle.V0: - return _exception_scopes.user_entry_point(self.task_function)( - ExecutionParameters( - execution_date=context.execution_date, - # TODO: it might be better to consider passing the full struct - execution_id=_six.text_type(WorkflowExecutionIdentifier.promote_from_model(context.execution_id)), - stats=context.stats, - logging=context.logging, - tmp_dir=context.working_directory, - ), - **inputs, - ) - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. 
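Both the unit_test and local_execute entry points above take plain Python keyword arguments matching the task's declared inputs and cast them to Flyte literals before dispatching to the named engine. Roughly, against the 0.x decorator API being removed in this diff (names and values illustrative):

    from flytekit.sdk.tasks import inputs, outputs, python_task
    from flytekit.sdk.types import Types

    @inputs(num=Types.Integer)
    @outputs(doubled=Types.Integer)
    @python_task
    def double(wf_params, num, doubled):
        doubled.set(num * 2)

    result = double.unit_test(num=3)       # runs through the "unit" engine
    result = double.local_execute(num=3)   # runs through the "local" engine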
- """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( - inputs, {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in self.interface.inputs.items()} - ) - outputs_dict = { - name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(self.interface.outputs) - } - - # Old style - V0: If annotations are used to define outputs, do not append outputs to the inputs dict - if not self.task_function.__annotations__ or "return" not in self.task_function.__annotations__: - inputs_dict.update(outputs_dict) - self._execute_user_code(context, inputs_dict) - return { - _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={k: v.sdk_value for k, v in _six.iteritems(outputs_dict)} - ) - } - - @_exception_scopes.system_entry_point - def fast_register(self, project, domain, name, digest, additional_distribution, dest_dir) -> str: - """ - The fast register call essentially hijacks the task container commandline. - Say an existing task container definition had a commandline like so: - flyte_venv pyflyte-execute --task-module app.workflows.my_workflow --task-name my_task - - The fast register command introduces a wrapper call to fast-execute the original commandline like so: - flyte_venv pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/foo/bar/12345.tar.gz -- - flyte_venv pyflyte-execute --task-module app.workflows.my_workflow --task-name my_task - - At execution time pyflyte-fast-execute will ensure the additional distribution (i.e. the fast-registered code) - exists before calling the original task commandline. - - :param Text project: The project in which to register this task. - :param Text domain: The domain in which to register this task. - :param Text name: The name to give this task. - :param Text digest: The version in which to register this task. - :param Text additional_distribution: User-specified location for remote source code distribution. - :param Text The optional location for where to install the additional distribution at runtime - :rtype: Text: Registered identifier. - """ - - original_container = self.container - container = _copy.deepcopy(original_container) - args = ["pyflyte-fast-execute", "--additional-distribution", additional_distribution] - if dest_dir: - args += ["--dest-dir", dest_dir] - args += ["--"] + container.args - container._args = args - self._container = container - - try: - registered_id = self.register(project, domain, name, digest) - except Exception: - self._container = original_container - raise - self._has_fast_registered = True - self._container = original_container - return str(registered_id) - - @property - def has_fast_registered(self) -> bool: - return self._has_fast_registered - - def _get_container_definition( - self, - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - environment=None, - cls=None, - ): - """ - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param dict[Text,Text] environment: - :param cls Optional[type]: Type of container to instantiate. Generally should subclass SdkRunnableContainer. 
- :rtype: flytekit.models.task.Container - """ - storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() - storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() - cpu_limit = cpu_limit or _resource_config.DEFAULT_CPU_LIMIT.get() - cpu_request = cpu_request or _resource_config.DEFAULT_CPU_REQUEST.get() - gpu_limit = gpu_limit or _resource_config.DEFAULT_GPU_LIMIT.get() - gpu_request = gpu_request or _resource_config.DEFAULT_GPU_REQUEST.get() - memory_limit = memory_limit or _resource_config.DEFAULT_MEMORY_LIMIT.get() - memory_request = memory_request or _resource_config.DEFAULT_MEMORY_REQUEST.get() - - resources = SdkRunnableContainer.get_resources( - storage_request, cpu_request, gpu_request, memory_request, storage_limit, cpu_limit, gpu_limit, memory_limit - ) - - return (cls or SdkRunnableContainer)( - command=[], - args=[ - "pyflyte-execute", - "--task-module", - self.task_module, - "--task-name", - self.task_function_name, - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - ], - resources=resources, - env=environment, - config={}, - ) - - def _validate_inputs(self, inputs): - """ - This method should be overridden in sub-classes that intend to do additional checks on inputs. If validation - fails, this function should raise an informative exception. - :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - super(SdkRunnableTask, self)._validate_inputs(inputs) - for k, v in _six.iteritems(inputs): - if not self._is_argname_in_function_definition(k): - raise _user_exceptions.FlyteValidationException( - "The input named '{}' was not specified in the task function. Therefore, this input cannot be " - "provided to the task.".format(k) - ) - if _type_helpers.get_sdk_type_from_literal_type(v.type) in type(self)._banned_inputs: - raise _user_exceptions.FlyteValidationException( - "The input '{}' is not an accepted input type.".format(v) - ) - - def _validate_outputs(self, outputs): - """ - This method should be overridden in sub-classes that intend to do additional checks on outputs. If validation - fails, this function should raise an informative exception. - :param dict[Text, flytekit.models.interface.Variable] outputs: Output variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - super(SdkRunnableTask, self)._validate_outputs(outputs) - for k, v in _six.iteritems(outputs): - if not self._is_argname_in_function_definition(k): - raise _user_exceptions.FlyteValidationException( - "The output named '{}' was not specified in the task function. 
Therefore, this output cannot be " - "provided to the task.".format(k) - ) - if _type_helpers.get_sdk_type_from_literal_type(v.type) in type(self)._banned_outputs: - raise _user_exceptions.FlyteValidationException( - "The output '{}' is not an accepted output type.".format(v) - ) - - def _get_kwarg_inputs(self): - # Trim off first parameter as it is reserved for workflow_parameters - return set(_getargspec(self.task_function).args[1:]) - - def _is_argname_in_function_definition(self, key): - return key in self._get_kwarg_inputs() - - def _missing_mapped_inputs_outputs(self): - # Trim off first parameter as it is reserved for workflow_parameters - args = self._get_kwarg_inputs() - inputs_and_outputs = set(self.interface.outputs.keys()) | set(self.interface.inputs.keys()) - return args ^ inputs_and_outputs diff --git a/flytekit/common/tasks/sidecar_task.py b/flytekit/common/tasks/sidecar_task.py deleted file mode 100644 index 15cb62d760..0000000000 --- a/flytekit/common/tasks/sidecar_task.py +++ /dev/null @@ -1,245 +0,0 @@ -import six as _six -from flyteidl.core import tasks_pb2 as _core_task -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models -from flytekit.plugins import k8s as _lazy_k8s - - -class SdkSidecarTask(_sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.ExtendedSdkType): - - """ - This class includes the additional logic for building a task that executes as a Sidecar Job. - - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - pod_spec=None, - primary_container_name=None, - annotations=None, - labels=None, - ): - """ - :param task_function: Function containing user code. This will be executed via the SDK's engine. - :param generated_pb2.PodSpec pod_spec: - :param Text primary_container_name: - :param dict[Text, Text] annotations: - :param dict[Text, Text] labels: - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - if not pod_spec: - raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined") - if not primary_container_name: - raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") - - super(SdkSidecarTask, self).__init__( - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - custom=None, - ) - - self.reconcile_partial_pod_spec_and_task(pod_spec, primary_container_name, annotations, labels) - - def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name, annotations=None, labels=None): - """ - Assigns the custom field as the reconciled primary container and pod spec definition.
- :param _sdk_runnable.SdkRunnableTask sdk_runnable_task: - :param generated_pb2.PodSpec pod_spec: - :param Text primary_container_name: - :param dict[Text, Text] annotations: - :param dict[Text, Text] labels: - :rtype: SdkSidecarTask - """ - - # First, insert a placeholder primary container if it is not defined in the pod spec. - containers = pod_spec.containers - primary_exists = False - for container in containers: - if container.name == primary_container_name: - primary_exists = True - break - if not primary_exists: - containers.extend([_lazy_k8s.io.api.core.v1.generated_pb2.Container(name=primary_container_name)]) - - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes with the default values - # used in an SDK runnable task. - if container.name == primary_container_name: - container.image = self._container.image - # clear existing commands - del container.command[:] - container.command.extend(self._container.command) - # also clear existing args - del container.args[:] - container.args.extend(self._container.args) - - resource_requirements = _lazy_k8s.io.api.core.v1.generated_pb2.ResourceRequirements() - for resource in self._container.resources.limits: - resource_requirements.limits[ - _core_task.Resources.ResourceName.Name(resource.name).lower() - ].CopyFrom(_lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) - for resource in self._container.resources.requests: - resource_requirements.requests[ - _core_task.Resources.ResourceName.Name(resource.name).lower() - ].CopyFrom(_lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) - if resource_requirements.ByteSize(): - # Important! Only copy over resource requirements if they are non-empty. - container.resources.CopyFrom(resource_requirements) - - del container.env[:] - container.env.extend( - [ - _lazy_k8s.io.api.core.v1.generated_pb2.EnvVar(name=key, value=val) - for key, val in _six.iteritems(self._container.env) - ] - ) - - final_containers.append(container) - - del pod_spec.containers[:] - pod_spec.containers.extend(final_containers) - - sidecar_job_plugin = _task_models.SidecarJob( - pod_spec=pod_spec, - primary_container_name=primary_container_name, - annotations=annotations, - labels=labels, - ).to_flyte_idl() - - self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin)) - - -class SdkDynamicSidecarTask( - _sdk_dynamic.SdkDynamicTaskMixin, - SdkSidecarTask, - metaclass=_sdk_bases.ExtendedSdkType, -): - - """ - This class includes the additional logic for building a task that runs as - a Sidecar Job and executes parent-child tasks. - - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - allowed_failure_ratio, - max_concurrency, - environment, - cache_serializable, - pod_spec=None, - primary_container_name=None, - annotations=None, - labels=None, - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. 
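The placeholder-insertion step of reconcile_partial_pod_spec_and_task, isolated as a sketch; the k8s protobuf path is the same lazy plugin import the deleted module uses:

    from flytekit.plugins import k8s as _lazy_k8s

    def ensure_primary_container(pod_spec, primary_container_name):
        # Append an empty placeholder container when the user-supplied pod spec
        # does not already declare the primary container by name.
        for container in pod_spec.containers:
            if container.name == primary_container_name:
                return pod_spec
        pod_spec.containers.extend(
            [_lazy_k8s.io.api.core.v1.generated_pb2.Container(name=primary_container_name)]
        )
        return pod_spec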
- :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param bool interruptible: Whether or not task is interruptible - :param Text deprecated: - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param bool discoverable: - :param datetime.timedelta timeout: - :param float allowed_failure_ratio: - :param int max_concurrency: - :param dict[Text, Text] environment: - :param bool cache_serializable: - :param generated_pb2.PodSpec pod_spec: - :param Text primary_container_name: - :param dict[Text, Text] annotations: - :param dict[Text, Text] labels: - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - - SdkSidecarTask.__init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - cache_serializable, - pod_spec=pod_spec, - primary_container_name=primary_container_name, - annotations=annotations, - labels=labels, - ) - - _sdk_dynamic.SdkDynamicTaskMixin.__init__(self, allowed_failure_ratio, max_concurrency) diff --git a/flytekit/common/tasks/spark_task.py b/flytekit/common/tasks/spark_task.py deleted file mode 100644 index f3f55f211e..0000000000 --- a/flytekit/common/tasks/spark_task.py +++ /dev/null @@ -1,208 +0,0 @@ -import typing - -try: - from inspect import getfullargspec as _getargspec -except ImportError: - from inspect import getargspec as _getargspec - -import copy as _copy -import hashlib as _hashlib -import json as _json -import os as _os -import sys as _sys - -import six as _six -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit.bin import entrypoint as _entrypoint -from flytekit.common import constants as _constants -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.plugins import pyspark as _pyspark - - -class GlobalSparkContext(object): - _SPARK_CONTEXT = None - _SPARK_SESSION = None - - @classmethod - def get_spark_context(cls): - return cls._SPARK_CONTEXT - - @classmethod - def get_spark_session(cls): - return cls._SPARK_SESSION - - def __enter__(self): - GlobalSparkContext._SPARK_CONTEXT = _pyspark.SparkContext() - GlobalSparkContext._SPARK_SESSION = _pyspark.sql.SparkSession.builder.appName( - "Flyte Spark SQL Context" - ).getOrCreate() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - GlobalSparkContext._SPARK_CONTEXT.stop() - GlobalSparkContext._SPARK_CONTEXT = None - return False - - -class SdkRunnableSparkContainer(_sdk_runnable.SdkRunnableContainer): - @property - def args(self): - """ - Override args to remove the injection of command prefixes - :rtype: list[Text] - """ - return self._args - - -class SdkSparkTask(_sdk_runnable.SdkRunnableTask): - """ - This class includes the additional logic for building a task that executes as a Spark Job. 
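A sketch of how the GlobalSparkContext context manager above brackets user code; SdkSparkTask.execute does effectively this, and the computation shown is illustrative:

    with GlobalSparkContext():
        sc = GlobalSparkContext.get_spark_context()
        # The user function receives this SparkContext as its second positional
        # argument, after the workflow parameters.
        total = sc.parallelize(range(10)).sum()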
- - """ - - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - discoverable, - timeout, - spark_type, - spark_conf, - hadoop_conf, - environment, - cache_serializable, - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. - :param Text task_type: string describing the task type - :param Text discovery_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param bool interruptible: Whether or not task is interruptible - :param Text deprecated: - :param bool discoverable: - :param datetime.timedelta timeout: - :param dict[Text,Text] spark_conf: - :param dict[Text,Text] hadoop_conf: - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - :param bool cache_serializable: - """ - - spark_exec_path = _os.path.abspath(_entrypoint.__file__) - if spark_exec_path.endswith(".pyc"): - spark_exec_path = spark_exec_path[:-1] - - self._spark_job = _task_models.SparkJob( - spark_conf=spark_conf, - hadoop_conf=hadoop_conf, - application_file="local://" + spark_exec_path, - executor_path=_sys.executable, - main_class="", - spark_type=spark_type, - ) - super(SdkSparkTask, self).__init__( - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - "", - "", - "", - "", - "", - "", - "", - "", - discoverable, - timeout, - environment, - cache_serializable, - _MessageToDict(self._spark_job.to_flyte_idl()), - ) - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. 
- """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - outputs_dict = { - name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(self.interface.outputs) - } - - inputs_dict.update(outputs_dict) - - with GlobalSparkContext(): - _exception_scopes.user_entry_point(self.task_function)( - _sdk_runnable.ExecutionParameters( - execution_date=context.execution_date, - tmp_dir=context.working_directory, - stats=context.stats, - execution_id=context.execution_id, - logging=context.logging, - ), - GlobalSparkContext.get_spark_context(), - **inputs_dict - ) - return { - _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={k: v.sdk_value for k, v in _six.iteritems(outputs_dict)} - ) - } - - @property - def spark_conf(self): - return self._spark_job.spark_conf - - @property - def hadoop_conf(self): - return self._spark_job.hadoop_conf - - def _get_container_definition(self, **kwargs): - """ - :rtype: SdkRunnableSparkContainer - """ - return super(SdkSparkTask, self)._get_container_definition(cls=SdkRunnableSparkContainer, **kwargs) - - def _get_kwarg_inputs(self): - # Trim off first two parameters as they are reserved for workflow_parameters and spark_context - return set(_getargspec(self.task_function).args[2:]) - - def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None - ): - """ - Creates a new SparkJob instance with the modified configuration or timeouts - """ - tk = _copy.deepcopy(self) - tk._spark_job = self._spark_job.with_overrides(new_spark_conf, new_hadoop_conf) - tk._custom = _MessageToDict(tk._spark_job.to_flyte_idl()) - - salt = _hashlib.md5(_json.dumps(tk.custom, sort_keys=True).encode("utf-8")).hexdigest() - tk._id._name = "{}-{}".format(self._id.name, salt) - # We are overriding the platform name creation to prevent problems in dynamic - tk.assign_name(tk._id._name) - - return tk diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py deleted file mode 100644 index ba55399382..0000000000 --- a/flytekit/common/tasks/task.py +++ /dev/null @@ -1,423 +0,0 @@ -import hashlib as _hashlib -import json as _json -import logging as _logging -import uuid as _uuid - -import six as _six -from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct - -from flytekit.common import interface as _interfaces -from flytekit.common import nodes as _nodes -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import workflow_execution as _workflow_execution -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import hash as _hash_mixin -from flytekit.common.mixins import launchable as _launchable_mixin -from flytekit.common.mixins import registerable as _registerable -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import auth as _auth_config -from flytekit.configuration import internal as _internal_config -from flytekit.configuration import sdk as _sdk_config -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_model -from flytekit.models import execution as 
_admin_execution_models -from flytekit.models import task as _task_model -from flytekit.models.admin import common as _admin_common -from flytekit.models.core import identifier as _identifier_model -from flytekit.models.core import workflow as _workflow_model - - -class SdkTask( - _hash_mixin.HashOnReferenceMixin, - _registerable.RegisterableEntity, - _launchable_mixin.LaunchableEntity, - _task_model.TaskTemplate, - metaclass=_sdk_bases.ExtendedSdkType, -): - def __init__( - self, type, metadata, interface, custom, container=None, task_type_version=0, security_context=None, config=None - ): - """ - :param Text type: This is used to define additional extensions for use by Propeller or SDK. - :param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as - whether or not outputs are discoverable, timeouts, and retries. - :param flytekit.common.interface.TypedInterface interface: The interface definition for this task. - :param dict[Text, T] custom: Arbitrary type for use by plugins. - :param Container container: Provides the necessary entrypoint information for execution. For instance, - a Container might be specified with the necessary command line arguments. - :param int task_type_version: Specific version of this task type used by plugins to potentially modify - execution behavior or serialization. - :param _SecurityContext security_context: - """ - # TODO: Remove the identifier portion and fill in with local values. - super(SdkTask, self).__init__( - _identifier.Identifier( - _identifier_model.ResourceType.TASK, - _internal_config.PROJECT.get(), - _internal_config.DOMAIN.get(), - _uuid.uuid4().hex, - _internal_config.VERSION.get(), - ), - type, - metadata, - interface, - custom, - container=container, - task_type_version=task_type_version, - security_context=security_context, - config=config, - ) - - @property - def interface(self): - """ - :rtype: flytekit.common.interface.TypedInterface - """ - return super(SdkTask, self).interface - - @property - def resource_type(self): - """ - Integer from _identifier.ResourceType enum - :rtype: int - """ - return _identifier_model.ResourceType.TASK - - @property - def entity_type_text(self): - """ - :rtype: Text - """ - return "Task" - - @classmethod - def promote_from_model(cls, base_model): - """ - :param flytekit.models.task.TaskTemplate base_model: - :rtype: SdkTask - """ - t = cls( - type=base_model.type, - metadata=base_model.metadata, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - custom=base_model.custom, - container=base_model.container, - task_type_version=base_model.task_type_version, - ) - # Override the newly generated name if one exists in the base model - if not base_model.id.is_empty: - t._id = _identifier.Identifier.promote_from_model(base_model.id) - - return t - - def assign_custom_and_return(self, custom): - self._custom = custom - return self - - def assign_type_and_return(self, new_type): - self._type = new_type - return self - - @_exception_scopes.system_entry_point - def __call__(self, *args, **input_map): - """ - :param list[T] args: Do not specify. Kwargs only are supported for this function. - :param dict[str, T] input_map: Map of inputs. Can be statically defined or OutputReference links. - :rtype: flytekit.common.nodes.SdkNode - """ - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - "When adding a task as a node in a workflow, all inputs must be specified with kwargs only. 
We " - "detected {} positional args.".format(len(args)) - ) - - bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map) - - # TODO: Remove DEADBEEF - # One thing to note - this function is not overloaded at the SdkRunnableTask layer, which means 'self' here - # will sometimes refer to an object that can be executed locally, and other times will refer to something - # that cannot (ie a pure SdkTask object, fetched from Admin for instance). - return _nodes.SdkNode( - id=None, - metadata=_workflow_model.NodeMetadata( - "DEADBEEF", - self.metadata.timeout, - self.metadata.retries, - self.metadata.interruptible, - ), - bindings=sorted(bindings, key=lambda b: b.var), - upstream_nodes=upstream_nodes, - sdk_task=self, - ) - - @_exception_scopes.system_entry_point - def register(self, project, domain, name, version): - """ - :param Text project: The project in which to register this task. - :param Text domain: The domain in which to register this task. - :param Text name: The name to give this task. - :param Text version: The version in which to register this task. - """ - # TODO: Revisit the notion of supplying the project, domain, name, version, as opposed to relying on the - # current ID. - self.validate() - id_to_register = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version) - old_id = self.id - - client = _flyte_engine.get_client() - try: - self._id = id_to_register - client.create_task(id_to_register, _task_model.TaskSpec(self)) - self._id = old_id - self._has_registered = True - return str(id_to_register) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - except Exception: - self._id = old_id - raise - - @_exception_scopes.system_entry_point - def serialize(self): - """ - :rtype: flyteidl.admin.task_pb2.TaskSpec - """ - return _task_model.TaskSpec(self).to_flyte_idl() - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, project, domain, name, version): - """ - This function uses the engine loader to call create a hydrated task from Admin. - :param Text project: - :param Text domain: - :param Text name: - :param Text version: - :rtype: SdkTask - """ - task_id = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version) - admin_task = _flyte_engine.get_client().get_task(task_id) - - sdk_task = cls.promote_from_model(admin_task.closure.compiled_task.template) - sdk_task._id = task_id - sdk_task._has_registered = True - return sdk_task - - @classmethod - @_exception_scopes.system_entry_point - def fetch_latest(cls, project, domain, name): - """ - This function uses the engine loader to call create a latest hydrated task from Admin. 
- :param Text project: - :param Text domain: - :param Text name: - :rtype: SdkTask - """ - named_task = _common_model.NamedEntityIdentifier(project, domain, name) - client = _flyte_engine.get_client() - task_list, _ = client.list_tasks_paginated( - named_task, - limit=1, - sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), - ) - admin_task = task_list[0] if task_list else None - - if not admin_task: - raise _user_exceptions.FlyteEntityNotExistException("Named task {} not found".format(named_task)) - sdk_task = cls.promote_from_model(admin_task.closure.compiled_task.template) - sdk_task._id = admin_task.id - return sdk_task - - @_exception_scopes.system_entry_point - def validate(self): - pass - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs): - raise _user_exceptions.FlyteUserException("You cannot add inputs to this task") - - @_exception_scopes.system_entry_point - def add_outputs(self, outputs): - """ - Adds the outputs to this task. This can be called multiple times, but it will fail if an output with a given - name is added more than once, a name collides with an input, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] outputs: names and variables to add as outputs - to this task - """ - self._validate_outputs(outputs) - self.interface.outputs.update(outputs) - - def _validate_inputs(self, inputs): - """ - This method should be overridden in sub-classes that intend to do additional checks on inputs. If validation - fails, this function should raise an informative exception. - :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - for k, v in _six.iteritems(inputs): - if k in self.interface.inputs: - raise _user_exceptions.FlyteValidationException( - "An input with name '{}' is already defined. Redefinition is not allowed.".format(k) - ) - if k in self.interface.outputs: - raise _user_exceptions.FlyteValidationException( - "An output with name '{}' is already defined. Therefore '{}' can't be defined as an " - "input".format(k, v) - ) - - def _validate_outputs(self, outputs): - """ - This method should be overridden in sub-classes that intend to do additional checks on outputs. If validation - fails, this function should raise an informative exception. - :param dict[Text, flytekit.models.interface.Variable] outputs: Output variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - for k, v in _six.iteritems(outputs): - if k in self.interface.outputs: - raise _user_exceptions.FlyteValidationException( - "An output with name '{}' is already defined. Redefinition is not allowed.".format(k) - ) - if k in self.interface.inputs: - raise _user_exceptions.FlyteValidationException( - "An input with name '{}' is already defined.
Therefore '{}' can't be defined as an " - "output".format(k, v) - ) - - def __repr__(self): - return "Flyte {task_type}: {interface}".format(task_type=self.type, interface=self.interface) - - def _python_std_input_map_to_literal_map(self, inputs): - """ - :param dict[Text,Any] inputs: A dictionary of Python standard inputs that will be type-checked and compiled - to a LiteralMap - :rtype: flytekit.models.literals.LiteralMap - """ - return _type_helpers.pack_python_std_map_to_literal_map( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - - def _produce_deterministic_version(self, version=None): - """ - :param Text version: - :return Text: - """ - - if self.container is not None and self.container.data_loading_config is None: - # Only in the case of raw container tasks (which are the only valid tasks with container definitions that - # can assign a client-side task version) their data config will be None. - raise ValueError("Client-side task versions are not supported for {} task type".format(self.type)) - if version is not None: - return version - custom = _json_format.Parse(_json.dumps(self.custom, sort_keys=True), _struct.Struct()) if self.custom else None - - # The task body is the entirety of the task template MINUS the identifier. The identifier is omitted because - # 1) this method is used to compute the version portion of the identifier and - # 2) the SDK will actually generate a unique name on every task instantiation which is not great for - # the reproducibility this method attempts. - task_body = ( - self.type, - self.metadata.to_flyte_idl().SerializeToString(deterministic=True), - self.interface.to_flyte_idl().SerializeToString(deterministic=True), - custom, - ) - return _hashlib.md5(str(task_body).encode("utf-8")).hexdigest() - - @_exception_scopes.system_entry_point - def register_and_launch(self, project, domain, name, version=None, inputs=None): - """ - :param Text project: The project in which to register and launch this task. - :param Text domain: The domain in which to register and launch this task. - :param Text name: The name to give this task. - :param Text version: The version in which to register this task - :param dict[Text, Any] inputs: A dictionary of Python standard inputs that will be type-checked, then compiled - to a LiteralMap. - - :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution - """ - self.validate() - version = self._produce_deterministic_version(version) - self.register(project, domain, name, version) - return self.launch(project, domain, inputs=inputs) - - @_exception_scopes.system_entry_point - def launch_with_literals( - self, - project, - domain, - literal_inputs, - name=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Launches a single task execution and returns the execution identifier. - :param Text project: - :param Text domain: - :param flytekit.models.literals.LiteralMap literal_inputs: Inputs to the execution. - :param Text name: [Optional] If specified, an execution will be created with this name. Note: the name must - be unique within the context of the project and domain. - :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these - are the notifications that will be honored for this execution. An empty list signals to disable all - notifications.
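The hashing step of _produce_deterministic_version above as a standalone sketch; the identifier is deliberately excluded so that re-instantiating the same task yields a stable version:

    import hashlib

    task_body = (
        task.type,
        task.metadata.to_flyte_idl().SerializeToString(deterministic=True),
        task.interface.to_flyte_idl().SerializeToString(deterministic=True),
        custom,  # deterministic protobuf Struct of the custom field, or None
    )
    version = hashlib.md5(str(task_body).encode("utf-8")).hexdigest()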
- :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: - :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution - """ - disable_all = notification_overrides == [] - if disable_all: - notification_overrides = None - else: - notification_overrides = _admin_execution_models.NotificationList(notification_overrides or []) - disable_all = None - - # Unlike regular workflow executions, single task executions must always specify an auth role, since there isn't - # any existing launch plan with a bound auth role to fall back on. - if auth_role is None: - assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() - kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() - - if not (assumable_iam_role or kubernetes_service_account): - _logging.warning( - "Using deprecated `role` from config. " - "Please update your config to use `assumable_iam_role` instead" - ) - assumable_iam_role = _sdk_config.ROLE.get() - auth_role = _common_model.AuthRole( - assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account, - ) - - client = _flyte_engine.get_client() - try: - # TODO(katrogan): Add handling to register the underlying task if it's not already. - exec_id = client.create_execution( - project, - domain, - name, - _admin_execution_models.ExecutionSpec( - self.id, - _admin_execution_models.ExecutionMetadata( - _admin_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - "sdk", # TODO: get principle - 0, # TODO: Detect nesting - ), - notifications=notification_overrides, - disable_all=disable_all, - labels=label_overrides, - annotations=annotation_overrides, - auth_role=auth_role, - ), - literal_inputs, - ) - except _user_exceptions.FlyteEntityAlreadyExistsException: - exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) - execution = client.get_execution(exec_id) - return _workflow_execution.SdkWorkflowExecution.promote_from_model(execution) diff --git a/flytekit/common/types/__init__.py b/flytekit/common/types/__init__.py deleted file mode 100644 index b2e2a3729a..0000000000 --- a/flytekit/common/types/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This package contains the runtime logic wrapping our type models. -""" diff --git a/flytekit/common/types/base_sdk_types.py b/flytekit/common/types/base_sdk_types.py deleted file mode 100644 index ab12969865..0000000000 --- a/flytekit/common/types/base_sdk_types.py +++ /dev/null @@ -1,141 +0,0 @@ -import abc as _abc - -from flyteidl.core.literals_pb2 import Literal - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_models -from flytekit.models import literals as _literal_models - - -class FlyteSdkType(_sdk_bases.ExtendedSdkType, metaclass=_common_models.FlyteABCMeta): - @_abc.abstractmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - pass - - @_abc.abstractmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - pass - - @_abc.abstractmethod - def from_string(cls, string_value): - """ - :param Text string_value: It is up to each individual object to implement this. 
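The notification handling in launch_with_literals above encodes an easy-to-miss convention, sketched here: an explicit empty list disables all notifications, while None defers to the defaults:

    if notification_overrides == []:
        disable_all, notifications = True, None
    else:
        disable_all = None
        notifications = _admin_execution_models.NotificationList(notification_overrides or [])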
- :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - pass - - @_abc.abstractmethod - def promote_from_model(cls, literal): - """ - :param flytekit.models.literals.Literal literal: - :rtype: FlyteSdkValue - """ - pass - - @_abc.abstractmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - pass - - def __hash__(cls): - return hash(cls.to_flyte_literal_type()) - - -class FlyteSdkValue(_literal_models.Literal, metaclass=FlyteSdkType): - @classmethod - def from_flyte_idl(cls, pb2_object: Literal): - """ - :param flyteidl.core.literals_pb2.Literal pb2_object: - :rtype: FlyteSdkValue - """ - literal = _literal_models.Literal.from_flyte_idl(pb2_object) - if literal.scalar is not None and literal.scalar.none_type is not None: - return Void() - return cls.promote_from_model(literal) - - @_abc.abstractmethod - def to_python_std(self): - pass - - -class InstantiableType(FlyteSdkType, metaclass=_common_models.FlyteABCMeta): - @_abc.abstractmethod - def __call__(cls, *args, **kwargs): - """ - TODO: Figure out generics for type hinting. - - :rtype: T - """ - return super(InstantiableType, cls).__call__(*args, **kwargs) - - -class Void(FlyteSdkValue): - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return True - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - return cls() - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - raise _user_exceptions.FlyteAssertion( - "A Void type does not have a literal type and cannot be used in this " "manner." - ) - - @classmethod - def promote_from_model(cls, _): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal _: - :rtype: Void - """ - return cls() - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Void" - - def __init__(self): - super(Void, self).__init__(scalar=_literal_models.Scalar(none_type=_literal_models.Void())) - - def to_python_std(self): - """ - :rtype: NoneType - """ - return None - - def short_string(self): - """ - :rtype: Text - """ - return "Void()" diff --git a/flytekit/common/types/blobs.py b/flytekit/common/types/blobs.py deleted file mode 100644 index 7870cb75bd..0000000000 --- a/flytekit/common/types/blobs.py +++ /dev/null @@ -1,465 +0,0 @@ -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types -from flytekit.models.core import types as _core_types - - -class BlobInstantiator(_base_sdk_types.InstantiableType): - @staticmethod - def create_at_known_location(location): - """ - :param Text location: - :rtype: flytekit.common.types.impl.blobs.Blob - """ - return _blob_impl.Blob.create_at_known_location(location, mode="wb") - - @staticmethod - def fetch(remote_path, local_path=None): - """ - :param Text remote_path: - :param Text local_path: [Optional] If specified, the Blob is copied to this location. 
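A sketch of the null-handling path above: any literal carrying a none_type scalar is promoted to Void, which round-trips to Python None:

    v = Void()
    assert v.to_python_std() is None
    assert v.scalar.none_type is not None
    # FlyteSdkValue.from_flyte_idl returns Void() for any none_type scalar.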
If specified, - this location is NOT managed and the blob will not be cleaned up upon exit. - :rtype: flytekit.common.types.impl.blobs.Blob - """ - return _blob_impl.Blob.fetch(remote_path, mode="rb", local_path=local_path) - - def __call__(cls, *args, **kwargs): - """ - TODO: Is there a better way to deal with this? - - We want Types.Blob() to return a _blob_impl.Blob, but we also want to be able to use this object to - wrap a _blob_impl.Blob via Types.Blob(_blob_impl.Blob()) for serialization, type checking, etc. - - :rtype: flytekit.common.types.impl.blobs.Blob - """ - if not args and not kwargs: - return _blob_impl.Blob.create_at_any_location(mode="wb") - else: - return super(BlobInstantiator, cls).__call__(*args, **kwargs) - - -# TODO: Make blobs and schemas pluggable -class Blob(_base_sdk_types.FlyteSdkValue, metaclass=BlobInstantiator): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Blob - """ - if not string_value: - raise _user_exceptions.FlyteValueException(string_value, "Cannot create a Blob from the provided path value.") - return cls(_blob_impl.Blob.from_string(string_value, mode="rb")) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, _blob_impl.Blob): - blob = t_value - else: - blob = _blob_impl.Blob.from_python_std(t_value) - return cls(blob) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE) - ) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: Blob - """ - return cls(_blob_impl.Blob.promote_from_model(literal_model.scalar.blob)) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Blob" - - def __init__(self, value): - """ - :param flytekit.common.types.impl.blobs.Blob value: Blob value to wrap - """ - super(Blob, self).__init__(scalar=_literals.Scalar(blob=value)) - - def to_python_std(self): - """ - :rtype: flytekit.common.types.impl.blobs.Blob - """ - return self.scalar.blob - - def short_string(self): - """ - :rtype: Text - """ - return "Blob(uri={}{})".format( - self.scalar.blob.uri, - ", format={}".format(self.scalar.blob.metadata.type.format) - if self.scalar.blob.metadata.type.format - else "", - ) - - class MultiPartBlobInstantiator(_base_sdk_types.InstantiableType): - @staticmethod - def create_at_known_location(location): - """ - :param Text location: - :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - return _blob_impl.MultiPartBlob.create_at_known_location(location, mode="wb") - - @staticmethod - def fetch(remote_path, local_path=None): - """ - :param Text remote_path: - :param Text local_path: [Optional] If specified, the MultiPartBlob is copied to this location. If specified, - this location is NOT managed and the blob will not be cleaned up upon exit.
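The three instantiation paths exposed by BlobInstantiator above, sketched with illustrative paths:

    b1 = Blob()                                                  # create_at_any_location(mode="wb")
    b2 = Blob.create_at_known_location("s3://bucket/data")       # unmanaged known remote location
    b3 = Blob.fetch("s3://bucket/data", local_path="/tmp/data")  # downloaded; local copy is unmanaged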
- :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - return _blob_impl.MultiPartBlob.fetch(remote_path, mode="rb", local_path=local_path) - - def __call__(cls, *args, **kwargs): - """ - TODO: Is there a better way to deal with this? - - We want Types.MultiPartBlob() to return a _blob_impl.MultiPartBlob, but we also want to be able to use - this object to wrap a _blob_impl.MultiPartBlob via Types.MultiPartBlob(_blob_impl.MultiPartBlob()) for - serialization, type checking, etc. - - :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - if not args and not kwargs: - return _blob_impl.MultiPartBlob.create_at_any_location(mode="wb") - else: - return super(MultiPartBlobInstantiator, cls).__call__(*args, **kwargs) - - class MultiPartBlob(_base_sdk_types.FlyteSdkValue, metaclass=MultiPartBlobInstantiator): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: MultiPartBlob - """ - if not string_value: - raise _user_exceptions.FlyteValueException( - string_value, - "Cannot create a MultiPartBlob from the provided path " "value.", - ) - return cls(_blob_impl.MultiPartBlob.from_string(string_value, mode="rb")) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, _blob_impl.MultiPartBlob): - blob = t_value - else: - blob = _blob_impl.MultiPartBlob.from_python_std(t_value) - return cls(blob) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: MultiPartBlob - """ - return cls(_blob_impl.MultiPartBlob.promote_from_model(literal_model.scalar.blob)) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "MultiPartBlob" - - def __init__(self, value): - """ - :param flytekit.common.types.impl.blobs.MultiPartBlob value: Blob value to wrap - """ - super(MultiPartBlob, self).__init__(scalar=_literals.Scalar(blob=value)) - - def to_python_std(self): - """ - :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - return self.scalar.blob - - def short_string(self): - """ - :rtype: Text - """ - return "MultiPartBlob(uri={}{})".format( - self.scalar.blob.uri, - ", format={}".format(self.scalar.blob.metadata.type.format) - if self.scalar.blob.metadata.type.format - else "", - ) - - class CsvInstantiator(BlobInstantiator): - @staticmethod - def create_at_known_location(location): - """ - :param Text location: - :rtype: flytekit.common.types.impl.blobs.CSV - """ - return _blob_impl.Blob.create_at_known_location(location, mode="w", format="csv") - - @staticmethod - def fetch(remote_path, local_path=None): - """ - :param Text remote_path: - :param Text local_path: [Optional] If specified, the CSV is copied to this location.
If specified, - this location is NOT managed and the blob will not be cleaned up upon exit. - :rtype: flytekit.common.types.impl.blobs.CSV - """ - return _blob_impl.Blob.fetch(remote_path, local_path=local_path, mode="r", format="csv") - - def __call__(cls, *args, **kwargs): - """ - TODO: Is there a better way to deal with this? - - We want Types.CSV() to return a _blob_impl.CSV, but we also want to be able to use - this object to wrap a _blob_impl.CSV via Types.CSV(_blob_impl.CSV()) for - serialization, type checking, etc. - - :rtype: flytekit.common.types.impl.blobs.CSV - """ - if not args and not kwargs: - return _blob_impl.Blob.create_at_any_location(mode="w", format="csv") - else: - return super(CsvInstantiator, cls).__call__(*args, **kwargs) - - class CSV(Blob, metaclass=CsvInstantiator): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: CSV - """ - if not string_value: - raise _user_exceptions.FlyteValueException(string_value, "Cannot create a CSV from the provided path value.") - return cls(_blob_impl.Blob.from_string(string_value, format="csv", mode="r")) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, _blob_impl.Blob): - if t_value.metadata.type.format != "csv": - raise _user_exceptions.FlyteValueException(t_value, "Blob is in incorrect format. Expected CSV.") - blob = t_value - else: - blob = _blob_impl.Blob.from_python_std(t_value, format="csv", mode="w") - return cls(blob) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) - ) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: CSV - """ - return cls(_blob_impl.Blob.promote_from_model(literal_model.scalar.blob, mode="r")) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "CSV" - - def __init__(self, value): - """ - :param flytekit.common.types.impl.blobs.Blob value: CSV blob value to wrap - """ - super(CSV, self).__init__(value) - - class MultiPartCsvInstantiator(MultiPartBlobInstantiator): - @staticmethod - def create_at_known_location(location): - """ - :param Text location: - :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - return _blob_impl.MultiPartBlob.create_at_known_location(location, mode="w", format="csv") - - @staticmethod - def fetch(remote_path, local_path=None): - """ - :param Text remote_path: - :param Text local_path: [Optional] If specified, the MultiPartCSV is copied to this location. If specified, - this location is NOT managed and the blob will not be cleaned up upon exit.
- :rtype: flytekit.common.types.impl.blobs.MultiPartCSV - """ - return _blob_impl.MultiPartBlob.fetch(remote_path, local_path=local_path, mode="r", format="csv") - - def __call__(cls, *args, **kwargs): - """ - TODO: Is there a better way to deal with this? - - We want Types.MultiPartCSV() to return a _blob_impl.MultiPartCSV, but we also want to be able to use - this object to wrap a _blob_impl.MultiPartCSV via Types.MultiPartCSV(_blob_impl.MultiPartCSV()) for - serialization, type checking, etc. - - :rtype: flytekit.common.types.impl.blobs.MultiPartCSV - """ - if not args and not kwargs: - return _blob_impl.MultiPartBlob.create_at_any_location(mode="w", format="csv") - else: - return super(MultiPartCsvInstantiator, cls).__call__(*args, **kwargs) - - class MultiPartCSV(MultiPartBlob, metaclass=MultiPartCsvInstantiator): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: MultiPartCSV - """ - if not string_value: - raise _user_exceptions.FlyteValueException( - string_value, - "Cannot create a MultiPartCSV from the provided path value.", - ) - return cls(_blob_impl.MultiPartBlob.from_string(string_value, format="csv", mode="r")) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, _blob_impl.MultiPartBlob): - if t_value.metadata.type.format != "csv": - raise _user_exceptions.FlyteValueException( - t_value, "Multi Part Blob is in incorrect format. Expected CSV." - ) - blob = t_value - else: - blob = _blob_impl.MultiPartBlob.from_python_std(t_value, format="csv", mode="w") - return cls(blob) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it.
- :param flytekit.models.literals.Literal literal_model: - :rtype: MultiPartCSV - """ - return cls(_blob_impl.MultiPartBlob.promote_from_model(literal_model.scalar.blob, mode="r")) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "MultiPartCSV" - - def __init__(self, value): - """ - :param flytekit.common.types.impl.blobs.MultiPartBlob value: MultiPartBlob value to wrap - """ - super(MultiPartCSV, self).__init__(value) diff --git a/flytekit/common/types/containers.py b/flytekit/common/types/containers.py deleted file mode 100644 index 9267273b1e..0000000000 --- a/flytekit/common/types/containers.py +++ /dev/null @@ -1,157 +0,0 @@ -import json as _json - -import six as _six - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types - - -class CollectionType(_base_sdk_types.FlyteSdkType): - pass - - -class TypedCollectionType(CollectionType): - @property - def sub_type(cls): - """ - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - return cls._sub_type - - def __eq__(cls, other): - return hasattr(other, "sub_type") and cls.sub_type == other.sub_type - - def __hash__(cls): - # Python 3 checks complain if hash isn't implemented at the same time as equals - return super(TypedCollectionType, cls).__hash__() - - -def List(sdk_type): - """ - :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - - class TList(TypedListImpl): - _sub_type = sdk_type - - # TODO: Figure out generics and type-hinting - return TList - - -class ListImpl(_base_sdk_types.FlyteSdkValue, metaclass=CollectionType): - def __len__(self): - return len(self.collection.literals) - - -class TypedListImpl(ListImpl, metaclass=TypedCollectionType): - @classmethod - def from_string(cls, string_value): - """ - Load the list from a JSON formatted string. - :param Text string_value: - :rtype: ListImpl - """ - try: - items = _json.loads(string_value) - except ValueError: - raise _user_exceptions.FlyteTypeException( - _six.text_type, - cls, - additional_msg="String not parseable to json {}".format(string_value), - ) - - if type(items) != list: - raise _user_exceptions.FlyteTypeException( - _six.text_type, - cls, - additional_msg="String is not a list {}".format(string_value), - ) - - # Instead of recursively calling from_string(), we're changing to from_python_std() instead because json - # loading naturally interprets all layers, not just the outer layer. - return cls([cls.sub_type.from_python_std(i) for i in items]) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - if not isinstance(type(other), TypedListImpl): - return False - return cls.sub_type.is_castable_from(other.sub_type) - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. 
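A sketch of the class-factory pattern above: List() manufactures a new TypedListImpl subclass per element type. Types.Integer is assumed to come from the legacy flytekit.sdk.types namespace:

    from flytekit.sdk.types import Types  # legacy SDK type namespace (assumed)

    IntList = List(Types.Integer)           # a freshly manufactured TList subclass
    lst = IntList.from_string("[1, 2, 3]")  # JSON-parse, then cast each element
    assert lst.to_python_std() == [1, 2, 3]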
- :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - if not isinstance(t_value, list): - raise _user_exceptions.FlyteTypeException(type(t_value), list, t_value) - return cls([cls.sub_type.from_python_std(v) for v in t_value]) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(collection_type=cls.sub_type.to_flyte_literal_type()) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: TypedListImpl - """ - return cls([cls.sub_type.from_flyte_idl(l.to_flyte_idl()) for l in literal_model.collection.literals]) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "List<{}>".format(cls.sub_type.short_class_string()) - - def __init__(self, value): - """ - :param list[flytekit.common.types.base_sdk_types.FlyteSdkValue] value: List value to wrap - """ - super(TypedListImpl, self).__init__(collection=_literals.LiteralCollection(literals=value)) - - def to_python_std(self): - """ - :rtype: list[T] - """ - return [type(self).sub_type.from_flyte_idl(l.to_flyte_idl()).to_python_std() for l in self.collection.literals] - - def short_string(self): - """ - :rtype: Text - """ - num_to_print = 5 - to_print = [v.short_string() for v in self.collection.literals[:num_to_print]] - if len(self.collection.literals) > num_to_print: - to_print.append("...") - return "{}(len={}, [{}])".format( - type(self).short_class_string(), - len(self.collection.literals), - ", ".join(to_print), - ) - - def verbose_string(self): - """ - :rtype: Text - """ - return "{}(\n\tlen={},\n\t[\n\t\t{}\n\t]\n)".format( - type(self).short_class_string(), - len(self.collection.literals), - ",\n\t\t".join("\n\t\t".join(v.verbose_string().splitlines()) for v in self.collection.literals), - ) diff --git a/flytekit/common/types/helpers.py b/flytekit/common/types/helpers.py deleted file mode 100644 index 92294f38fd..0000000000 --- a/flytekit/common/types/helpers.py +++ /dev/null @@ -1,124 +0,0 @@ -import importlib as _importlib - -import six as _six - -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.configuration import sdk as _sdk_config -from flytekit.models import literals as _literal_models - - -class _TypeEngineLoader(object): - _LOADED_ENGINES = None - _LAST_LOADED = None - - @classmethod - def _load_engines(cls): - config = _sdk_config.TYPE_ENGINES.get() - if cls._LOADED_ENGINES is None or config != cls._LAST_LOADED: - cls._LAST_LOADED = config - cls._LOADED_ENGINES = [] - for fqdn in config: - split = fqdn.split(".") - module_path, attr = ".".join(split[:-1]), split[-1] - module = _exception_scopes.user_entry_point(_importlib.import_module)(module_path) - - if not hasattr(module, attr): - raise _user_exceptions.FlyteValueException( - module, - "Failed to load the type engine because the attribute named '{}' could not be found" - "in the module '{}'.".format(attr, module_path), - ) - - engine_impl = getattr(module, attr)() - cls._LOADED_ENGINES.append(engine_impl) - from flytekit.type_engines.default.flyte import FlyteDefaultTypeEngine as _DefaultEngine - - cls._LOADED_ENGINES.append(_DefaultEngine()) - - @classmethod - def iterate_engines_in_order(cls): - """ - :rtype: 
Generator[flytekit.type_engines.common.TypeEngine]
-        """
-        cls._load_engines()
-        return iter(cls._LOADED_ENGINES)
-
-
-def python_std_to_sdk_type(t):
-    """
-    :param T t: User input. Should be of the form: Types.Integer, [Types.Integer], {Types.String: Types.Integer}, etc.
-    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
-    """
-    for e in _TypeEngineLoader.iterate_engines_in_order():
-        out = e.python_std_to_sdk_type(t)
-        if out is not None:
-            return out
-    raise _user_exceptions.FlyteValueException(t, "Could not resolve to an SDK type for this value.")
-
-
-def get_sdk_type_from_literal_type(literal_type):
-    """
-    :param flytekit.models.types.LiteralType literal_type:
-    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
-    """
-    for e in _TypeEngineLoader.iterate_engines_in_order():
-        out = e.get_sdk_type_from_literal_type(literal_type)
-        if out is not None:
-            return out
-    raise _user_exceptions.FlyteValueException(
-        literal_type, "Could not resolve to a type implementation for this value."
-    )
-
-
-def infer_sdk_type_from_literal(literal):
-    """
-    :param flytekit.models.literals.Literal literal:
-    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
-    """
-    for e in _TypeEngineLoader.iterate_engines_in_order():
-        out = e.infer_sdk_type_from_literal(literal)
-        if out is not None:
-            return out
-    raise _user_exceptions.FlyteValueException(literal, "Could not resolve to a type implementation for this value.")
-
-
-def get_sdk_value_from_literal(literal, sdk_type=None):
-    """
-    :param flytekit.models.literals.Literal literal:
-    :param flytekit.models.types.LiteralType sdk_type:
-    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkValue
-    """
-    # The spec states everything must be nullable, so if we receive a null value, swap to the null type behavior.
-    if sdk_type is None:
-        sdk_type = infer_sdk_type_from_literal(literal)
-    return sdk_type.from_flyte_idl(literal.to_flyte_idl())
-
-
-def unpack_literal_map_to_sdk_object(literal_map, type_map=None):
-    """
-    :param flytekit.models.literals.LiteralMap literal_map:
-    :param dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType] type_map: Type map directing unpacking.
-    :rtype: dict[Text, T]
-    """
-    type_map = type_map or {}
-    return {k: get_sdk_value_from_literal(v, sdk_type=type_map.get(k, None)) for k, v in literal_map.literals.items()}
-
-
-def unpack_literal_map_to_sdk_python_std(literal_map, type_map=None):
-    """
-    :param flytekit.models.literals.LiteralMap literal_map: Literal map containing values for unpacking.
-    :param dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType] type_map: Type map directing unpacking.
- :rtype: dict[Text, T] - """ - return {k: v.to_python_std() for k, v in unpack_literal_map_to_sdk_object(literal_map, type_map=type_map).items()} - - -def pack_python_std_map_to_literal_map(std_map, type_map): - """ - :param dict[Text, T] std_map: - :param dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType] type_map: - :rtype: flytekit.models.literals.LiteralMap - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - return _literal_models.LiteralMap(literals={k: v.from_python_std(std_map[k]) for k, v in _six.iteritems(type_map)}) diff --git a/flytekit/common/types/impl/__init__.py b/flytekit/common/types/impl/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/common/types/impl/blobs.py b/flytekit/common/types/impl/blobs.py deleted file mode 100644 index 45210da62a..0000000000 --- a/flytekit/common/types/impl/blobs.py +++ /dev/null @@ -1,497 +0,0 @@ -import os as _os -import shutil as _shutil -import sys as _sys -import uuid as _uuid - -import six as _six - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common import utils as _utils -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literal_models -from flytekit.models.core import types as _core_types - - -class Blob(_literal_models.Blob, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, remote_path, mode="rb", format=None): - """ - :param Text remote_path: Path to location where the Blob should be synced to. - :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. - :param Text format: Format - """ - if "+" in mode or "a" in mode or ("w" in mode and "r" in mode): - raise _user_exceptions.FlyteAssertion("A blob cannot be read and written at the same time") - self._mode = mode - self._local_path = None - self._file = None - super(Blob, self).__init__( - _literal_models.BlobMetadata( - type=_core_types.BlobType(format or "", _core_types.BlobType.BlobDimensionality.SINGLE) - ), - remote_path, - ) - - @classmethod - @_exception_scopes.system_entry_point - def from_python_std(cls, t_value, mode="wb", format=None): - """ - :param T t_value: - :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. - :param Text format: - :rtype: Blob - """ - if isinstance(t_value, (_six.text_type, str)): - if _os.path.isfile(t_value): - blob = cls.create_at_any_location(mode=mode, format=format) - blob._local_path = t_value - blob.upload() - else: - blob = cls.create_at_known_location(t_value, mode=mode, format=format) - return blob - elif isinstance(t_value, cls): - return t_value - else: - raise _user_exceptions.FlyteTypeException( - type(t_value), - {_six.text_type, str, Blob}, - received_value=t_value, - additional_msg="Unable to create Blob from user-provided value.", - ) - - @classmethod - @_exception_scopes.system_entry_point - def from_string(cls, t_value, mode="wb", format=None): - """ - :param T t_value: - :param Text mode: Read or write mode of the object. 
- :param Text format: - :rtype: Blob - """ - return cls.create_at_known_location(t_value, mode=mode, format=format) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode="wb", format=None): - """ - :param Text known_remote_location: The location to which to write the object. Usually an s3 path. - :param Text mode: - :param Text format: - :rtype: Blob - """ - return cls(known_remote_location, mode=mode, format=format) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode="wb", format=None): - """ - :param Text mode: - :param Text format: - :rtype: Blob - """ - return cls.create_at_known_location(_data_proxy.Data.get_remote_path(), mode=mode, format=format) - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", format=None): - """ - :param Text remote_path: The location from which to fetch the object. Usually an s3 path. - :param Text local_path: [Optional] A local path to which to download the object. If specified, the object - will not be managed and might not be cleaned up by the system upon exiting the context. - :param bool overwrite: If True, objects will be overwritten at the provided local_path in order to fetch this - object. Default is False. - :param Text mode: Read or write mode of the object. - :param Text format: Format the object is in. - :rtype: Blob - """ - blob = cls(remote_path, mode=mode, format=format) - blob.download(local_path=local_path, overwrite=overwrite) - return blob - - @classmethod - def promote_from_model(cls, model, mode="rb"): - """ - :param flytekit.models.literals.Blob model: - :param Text mode: Read or write mode of the object. - :rtype: Blob - """ - return cls(model.uri, format=model.metadata.type.format, mode=mode) - - @property - def local_path(self): - """ - Local filesystem path where the file was downloaded - :rtype: Text - """ - return self._local_path - - @property - def remote_location(self): - """ - Path to where this Blob will be synced. - :rtype: Text - """ - return self.uri - - @property - def mode(self): - """ - The mode string the Blob is associated with. - :rtype: Text - """ - return self._mode - - @_exception_scopes.system_entry_point - def __enter__(self): - """ - :rtype: typing.BinaryIO - """ - if self._file is not None: - raise _user_exceptions.FlyteAssertion("Only one reference can be open to a blob at a time.") - - if self.local_path is None: - if "r" in self.mode: - self.download() - elif "w" in self.mode: - self._generate_local_path() - - self._file = open(self.local_path, self.mode) - return self._file - - @_exception_scopes.system_entry_point - def __exit__(self, exc_type, exc_val, exc_tb): - if self._file is not None and not self._file.closed: - self._file.close() - self._file = None - if "w" in self.mode: - self.upload() - return False - - def _generate_local_path(self): - if _data_proxy.LocalWorkingDirectoryContext.get() is None: - raise _user_exceptions.FlyteAssertion( - "No temporary file system is present. Either call this method from within the " - "context of a task or surround with a 'with LocalTestFileSystem():' block. Or " - "specify a path when calling this function. Note: Cleanup is not automatic when a " - "path is specified." 
- ) - self._local_path = _data_proxy.LocalWorkingDirectoryContext.get().get_named_tempfile(_uuid.uuid4().hex) - - @_exception_scopes.system_entry_point - def download(self, local_path=None, overwrite=False): - """ - Alternate method, rather than the context manager interface to download the binary file to the local disk. - :param Text local_path: [Optional] If provided, the blob will be downloaded to this path. This will make the - resulting file object unmanaged and it will not be cleaned up by the system upon exiting the context. - :param bool overwrite: If true and local_path is specified, we will download the blob and - overwrite an existing file at that location. Default is False. - """ - if "r" not in self._mode: - raise _user_exceptions.FlyteAssertion("Cannot download a write-only blob!") - - if local_path: - self._local_path = local_path - - if not self.local_path: - self._generate_local_path() - - if overwrite or not _os.path.exists(self.local_path): - # TODO: Introduce system logging - # logging.info("Getting {} -> {}".format(self.remote_location, self.local_path)) - _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=False) - else: - raise _user_exceptions.FlyteAssertion( - "Cannot download blob to a location that already exists when overwrite is not set to True. " - "Attempted download from {} -> {}".format(self.remote_location, self.local_path) - ) - - @_exception_scopes.system_entry_point - def upload(self): - """ - Upload the blob to the remote location - """ - if "w" not in self.mode: - raise _user_exceptions.FlyteAssertion("Cannot upload a read-only blob!") - - elif not self.local_path: - raise _user_exceptions.FlyteAssertion( - "The Blob is not currently backed by a local file and therefore " - "cannot be uploaded. Please write to this Blob before attempting " - "an upload." - ) - else: - # TODO: Introduce system logging - # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location)) - _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=False) - - -class MultiPartBlob(_literal_models.Blob, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, remote_path, mode="rb", format=None): - """ - :param Text remote_path: Path to location where the Blob should be synced to. - :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. - :param Text format: Format of underlying blob pieces. - """ - remote_path = remote_path.strip().rstrip("/") + "/" - super(MultiPartBlob, self).__init__( - _literal_models.BlobMetadata( - type=_core_types.BlobType(format or "", _core_types.BlobType.BlobDimensionality.MULTIPART) - ), - remote_path, - ) - self._is_managed = False - self._blobs = [] - self._directory = None - self._mode = mode - - @classmethod - def promote_from_model(cls, model, mode="rb"): - """ - :param flytekit.models.literals.Blob model: - :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. - :rtype: Blob - """ - return cls(model.uri, format=model.metadata.type.format, mode=mode) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode="wb", format=None): - """ - :param Text known_remote_location: The location to which to write the object. Usually an s3 path. 
- :param Text mode: - :param Text format: - :rtype: MultiPartBlob - """ - return cls(known_remote_location, mode=mode, format=format) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode="wb", format=None): - """ - :param Text mode: - :param Text format: - :rtype: MultiPartBlob - """ - return cls.create_at_known_location(_data_proxy.Data.get_remote_path(), mode=mode, format=format) - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", format=None): - """ - :param Text remote_path: The location from which to fetch the object. Usually an s3 path. - :param Text local_path: [Optional] A local path to which to download the object. If specified, the object - will not be managed and might not be cleaned up by the system upon exiting the context. - :param bool overwrite: If True, objects will be overwritten at the provided local_path in order to fetch this - object. Default is False. - :param Text mode: Read or write mode of the object. - :param Text format: Format the object is in. - :rtype: MultiPartBlob - """ - blob = cls(remote_path, mode=mode, format=format) - blob.download(local_path=local_path, overwrite=overwrite) - return blob - - @classmethod - @_exception_scopes.system_entry_point - def from_python_std(cls, t_value, mode="wb", format=None): - """ - :param T t_value: - :param Text mode: Read or write mode of the object. - :param Text format: - :rtype: MultiPartBlob - """ - if isinstance(t_value, (str, _six.text_type)): - if _os.path.isdir(t_value): - # TODO: Infer format - blob = cls.create_at_any_location(mode=mode, format=format) - blob._directory = _utils.Directory(t_value) - blob.upload() - else: - blob = cls.create_at_known_location(t_value, mode=mode, format=format) - return blob - elif isinstance(t_value, cls): - return t_value - else: - raise _user_exceptions.FlyteTypeException( - type(t_value), - {str, _six.text_type, MultiPartBlob}, - received_value=t_value, - additional_msg="Unable to create Blob from user-provided value.", - ) - - @classmethod - @_exception_scopes.system_entry_point - def from_string(cls, t_value, mode="wb", format=None): - """ - :param T t_value: - :param Text mode: Read or write mode of the object. - :param Text format: - :rtype: MultiPartBlob - """ - return cls.create_at_known_location(t_value, mode=mode, format=format) - - @_exception_scopes.system_entry_point - def __enter__(self): - """ - :rtype: list[typing.BinaryIO] - """ - if "r" not in self.mode: - raise _user_exceptions.FlyteAssertion("Do not enter context to write to directory. Call create_piece") - - try: - if not self._directory: - if _data_proxy.LocalWorkingDirectoryContext.get() is None: - raise _user_exceptions.FlyteAssertion( - "No temporary file system is present. Either call this method from within the " - "context of a task or surround with a 'with LocalTestFileSystem():' block. Or " - "specify a path when calling this function. Note: Cleanup is not automatic when a " - "path is specified." 
- ) - self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, - tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, - ) - self._is_managed = True - self._directory.__enter__() - # TODO: Introduce system logging - # logging.info("Copying recursively {} -> {}".format(self.remote_location, self.local_path)) - _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=True) - - # Read the files into blobs in case-insensitive lexicographically ascending orders - self._blobs = [] - file_handles = [] - for local_path in sorted(self._directory.list_dir(), key=lambda x: x.lower()): - b = Blob( - _os.path.join(self.remote_location, _os.path.basename(local_path)), - mode=self.mode, - ) - b._local_path = local_path - file_handles.append(b.__enter__()) - self._blobs.append(b) - - return file_handles - except Exception: - # Exit is idempotent so close partially opened context that way - exc_type, exc_obj, exc_tb = _sys.exc_info() - self.__exit__(exc_type, exc_obj, exc_tb) - raise - - @_exception_scopes.system_entry_point - def __exit__(self, exc_type, exc_val, exc_tb): - for blob in self._blobs: - blob.__exit__(exc_type, exc_val, exc_tb) - self._blobs = [] - if self._is_managed: - self._directory.__exit__(exc_type, exc_val, exc_tb) - self._directory = None - self._is_managed = False - return False - - @property - def local_path(self): - """ - Local filesystem path where the file was downloaded - :rtype: Text - """ - if not self._directory: - return None - return self._directory.name - - @property - def remote_location(self): - """ - Path to where this MultiPartBlob will be synced. - :rtype: Text - """ - return self.uri - - @property - def mode(self): - """ - The mode string the MultiPartBlob is associated with. - :rtype: Text - """ - return self._mode - - @_exception_scopes.system_entry_point - def create_part(self, name=None): - """ - Method which will return a Blob object for writing into a multi-part blob. - - :param Text name: [optional] If name is provided, it is a specific partition name to place as part of the - multipart blob. When we read blobs from a multipart, it will be read in lexicographic order so this can be - used to enforce ordering. If not provided, the name is randomly generated. - :rtype: Blob - """ - if "w" not in self.mode: - raise _user_exceptions.FlyteAssertion("Cannot create a blob in a read-only multipart blob") - if name is None: - name = _uuid.uuid4().hex - if ":" in name or "/" in name: - raise _user_exceptions.FlyteAssertion( - name, - "Cannot create a part of a multi-part object with ':' or '/' in the name.", - ) - return Blob.create_at_known_location( - _os.path.join(self.remote_location, name), - mode=self.mode, - format=self.metadata.type.format, - ) - - @_exception_scopes.system_entry_point - def download(self, local_path=None, overwrite=False): - """ - Forces the download of the remote multi-part blob to the local machine. - :param Text local_path: [Optional] If provided, the blob pieces will be downloaded to this path. This will - make the resulting file objects unmanaged and it will not be cleaned up by the system upon exiting the - context. - :param bool overwrite: If true and local_path is specified, we will download the blob pieces and - overwrite any existing files at that location. Default is False. 
-        """
-        if "r" not in self.mode:
-            raise _user_exceptions.FlyteAssertion("Cannot download a write-only object!")
-
-        if local_path:
-            self._is_managed = False
-        elif _data_proxy.LocalWorkingDirectoryContext.get() is None:
-            raise _user_exceptions.FlyteAssertion(
-                "No temporary file system is present. Either call this method from within the "
-                "context of a task or surround with a 'with LocalTestFileSystem():' block. Or "
-                "specify a path when calling this function. Note: Cleanup is not automatic when a "
-                "path is specified."
-            )
-        else:
-            local_path = _data_proxy.LocalWorkingDirectoryContext.get().get_named_tempfile(_uuid.uuid4().hex)
-
-        path_exists = _os.path.exists(local_path)
-        if not path_exists or overwrite:
-            if path_exists:
-                _shutil.rmtree(local_path)
-            _os.makedirs(local_path)
-            self._directory = _utils.Directory(local_path)
-            _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=True)
-        else:
-            raise _user_exceptions.FlyteAssertion(
-                "Cannot download multi-part blob to a location that already exists when overwrite is not set to True. "
-                "Attempted download from {} -> {}".format(self.remote_location, self.local_path)
-            )
-
-    @_exception_scopes.system_entry_point
-    def upload(self):
-        """
-        Upload the multi-part blob to the remote location
-        """
-        if "w" not in self.mode:
-            raise _user_exceptions.FlyteAssertion("Cannot upload a read-only multi-part blob!")
-
-        elif not self.local_path:
-            raise _user_exceptions.FlyteAssertion(
-                "The multi-part blob is not currently backed by a local directory "
-                "and therefore cannot be uploaded. Please write to this multi-part blob before "
-                "attempting an upload."
-            )
-        else:
-            # TODO: Introduce system logging
-            # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location))
-            _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True)
diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py
deleted file mode 100644
index 04e4109d1c..0000000000
--- a/flytekit/common/types/impl/schema.py
+++ /dev/null
@@ -1,995 +0,0 @@
-import collections as _collections
-import os as _os
-import uuid as _uuid
-
-import six as _six
-
-from flytekit.common import sdk_bases as _sdk_bases
-from flytekit.common import utils as _utils
-from flytekit.common.exceptions import scopes as _exception_scopes
-from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.types import base_sdk_types as _base_sdk_types
-from flytekit.common.types import helpers as _helpers
-from flytekit.common.types import primitives as _primitives
-from flytekit.common.types.impl import blobs as _blob_impl
-from flytekit.configuration import sdk as _sdk_config
-from flytekit.interfaces.data import data_proxy as _data_proxy
-from flytekit.models import literals as _literal_models
-from flytekit.models import types as _type_models
-from flytekit.plugins import numpy as _np
-from flytekit.plugins import pandas as _pd
-
-# Note: For now, this is only for basic type-checking. We need not differentiate between TINYINT, BIGINT,
-# and INT or DOUBLE and FLOAT, VARCHAR and STRING, etc. as we will unpack into appropriate Python
-# objects anyway. If we work on managed tables, these more specific type specifications might become necessary.
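The deleted mapping below pairs each Flyte literal type with the numpy dtypes a schema column may hold; compare_dataframe_to_schema, further down in this file, walks it with _np.issubdtype. A minimal self-contained sketch of that check, assuming hypothetical column names and a simplified mapping keyed by column name rather than by Flyte literal type:

    import numpy as np
    import pandas as pd

    # Hypothetical, simplified stand-in for the literal-type -> dtypes mapping.
    allowed_dtypes = {"id": (np.integer,), "score": (np.floating,)}

    df = pd.DataFrame({"id": [1, 2], "score": [0.5, 0.7]})
    for name, candidates in allowed_dtypes.items():
        dtype = df[name].dtype
        # Same shape as the deleted check: reject when no candidate dtype matches.
        if all(not np.issubdtype(dtype, c) for c in candidates):
            raise TypeError("Column {!r} has unsupported dtype {}".format(name, dtype))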
-_SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES = None - - -def get_supported_literal_types_to_pandas_types(): - global _SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES - if _SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES is None: - _SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES = { - _primitives.Integer.to_flyte_literal_type(): {_np.int32, _np.int64, _np.uint32, _np.uint64}, - _primitives.Float.to_flyte_literal_type(): {_np.float32, _np.float64}, - _primitives.Boolean.to_flyte_literal_type(): {_np.bool}, - _primitives.Datetime.to_flyte_literal_type(): {_np.datetime64}, - _primitives.Timedelta.to_flyte_literal_type(): {_np.timedelta64}, - _primitives.String.to_flyte_literal_type(): {_np.object_, _np.str_, _np.string_}, - } - return _SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES - - -_ALLOWED_PARTITION_TYPES = {str, int} - -# Hive currently has limitations where column headers are not stored when writing to an overwrite directory. There is -# an open proposal (https://issues.apache.org/jira/browse/HIVE-12860) to improve this. Until then, we have this -# work-around where we create an external table with the appropriate schema and write the data to our desired -# location. The issue here is that the table information in the meta-store might not get cleaned up during a partial -# failure. -_HIVE_QUERY_FORMATTER = """ - {stage_query_str} - - CREATE TEMPORARY TABLE {table}_tmp AS {query_str}; - CREATE EXTERNAL TABLE {table} LIKE {table}_tmp STORED AS PARQUET; - ALTER TABLE {table} SET LOCATION '{url}'; - - INSERT OVERWRITE TABLE {table} - SELECT - {columnar_query} - FROM {table}_tmp; - DROP TABLE {table}; - """ - -# Once https://issues.apache.org/jira/browse/HIVE-12860 is resolved. We will prefer the following syntax because it -# guarantees cleanup on partial failures. -_HIVE_QUERY_FORMATTER_V2 = """ - CREATE TEMPORARY TABLE {table} AS {query_str}; - - INSERT OVERWRITE DIRECTORY '{url}' STORED AS PARQUET - SELECT {columnar_query} - FROM {table}; - """ - -# Set location in both parts of this query so in case of a partial failure, we will always have some data backing a -# partition. -_WRITE_HIVE_PARTITION_QUERY_FORMATTER = """ - ALTER TABLE {write_table} ADD IF NOT EXISTS {partition_string} LOCATION '{url}'; - - ALTER TABLE {write_table} {partition_string} SET LOCATION '{url}'; - """ - - -def _format_insert_partition_query(table_name, partition_string, remote_location): - table_pieces = table_name.split(".") - if len(table_pieces) > 1: - # Hive shell commands don't allow us to alter tables and select databases in the table specification. So - # we split the table name and use the 'use' command to choose the correct database. - prefix = "use {};\n".format(table_pieces[0]) - table_name = ".".join(table_pieces[1:]) - else: - prefix = "" - - return prefix + _WRITE_HIVE_PARTITION_QUERY_FORMATTER.format( - write_table=table_name, partition_string=partition_string, url=remote_location - ) - - -class _SchemaIO(object): - def __init__(self, schema_instance, local_dir, mode): - """ - :param Schema schema_instance: - :param flytekit.common.utils.Directory local_dir: - :param Text mode: - """ - self._schema = schema_instance - self._local_dir = local_dir - self._chunks = [] - self._index = 0 - self._mode = mode - - def _access_guard(self): - if not self._schema: - raise _user_exceptions.FlyteAssertion( - "Schema IO object has already been closed. Cannot access chunk_count property." 
- ) - - @_exception_scopes.system_entry_point - def iter_chunks(self, *args, **kwargs): - raise _user_exceptions.FlyteAssertion("{} is write only.".format(self._schema)) - - @_exception_scopes.system_entry_point - def read(self, *args, **kwargs): - raise _user_exceptions.FlyteAssertion("{} is write only.".format(self._schema)) - - @_exception_scopes.system_entry_point - def write(self, *args, **kwargs): - raise _user_exceptions.FlyteAssertion("{} is read only.".format(self._schema)) - - @_exception_scopes.system_entry_point - def close(self): - self._schema = None - self._local_dir = None - self._chunks = None - self._index = 0 - - @property - @_exception_scopes.system_entry_point - def chunk_count(self): - self._access_guard() - return len(self._chunks) - - @_exception_scopes.system_entry_point - def seek(self, index): - self._access_guard() - if index < 0 or index > self.chunk_count: - raise _user_exceptions.FlyteValueException( - index, - "Attempting to seek to a chunk that is out of range. Allowed range is [0, {}]".format(self.chunk_count), - ) - self._index = index - - @_exception_scopes.system_entry_point - def tell(self): - return self._index - - def __repr__(self): - return "{mode} IO Object for {type} @ {location}".format( - type=self._schema.type, location=self._schema.remote_prefix, mode=self._mode - ) - - -class _SchemaReader(_SchemaIO): - def __init__(self, schema_instance, local_dir): - """ - :param Schema schema_instance: - :param flytekit.common.utils.Directory local_dir: - """ - super(_SchemaReader, self).__init__(schema_instance, local_dir, "Read-Only") - self.reset_chunks() - - @_exception_scopes.system_entry_point - def reset_chunks(self): - self._chunks = sorted(self._local_dir.list_dir()) - - @_exception_scopes.system_entry_point - def iter_chunks(self, columns=None, **kwargs): - self._access_guard() - while self._index < len(self._chunks): - chunk = self.read(columns=columns, concat=False, **kwargs) - if chunk is not None: - yield chunk - - @staticmethod - def _read_parquet_with_type_promotion_override(chunk, columns, parquet_engine): - """ - This wrapper function of pd.read_parquet() is a hack intended to fix the type promotion problem - when using fastparquet as the underlying parquet engine. - - When using fastparquet, boolean columns containing None values will be promoted to float16 columns. - This becomes problematic when users want to write the dataframe back into parquet - file because float16 (halffloat) is not a supported type in parquet spec. In this function, we detect - such columns and do override the type promotion. 
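The hunk below implements the override described in this docstring. As a rough standalone illustration of the repair step (hypothetical column name, no fastparquet required), a nullable boolean column that came back as float16 can be restored to an object column of True/False/None; this sketch uses map() rather than the replace() call in the deleted code, purely to stay robust across pandas versions:

    import numpy as np
    import pandas as pd

    # Symptom: a nullable boolean column read back as float16 (0.0 / 1.0 / NaN).
    df = pd.DataFrame({"flag": np.array([1.0, 0.0, np.nan], dtype="float16")})

    # Repair: move to object dtype, then map NaN back to None and 0/1 to booleans.
    col = df["flag"].astype("object")
    df["flag"] = col.map(lambda v: None if pd.isna(v) else bool(v))
    assert list(df["flag"]) == [True, False, None]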
- """ - df = None - - if parquet_engine == "fastparquet": - import fastparquet.thrift_structures as _ts - from fastparquet import ParquetFile as _ParquetFile - - # https://github.com/dask/fastparquet/issues/414#issuecomment-478983811 - df = _pd.read_parquet(chunk, columns=columns, engine=parquet_engine, index=False) - df_column_types = df.dtypes - pf = _ParquetFile(chunk) - schema_column_dtypes = {l.name: l.type for l in list(pf.schema.schema_elements)} - - for idx in df_column_types[df_column_types == "float16"].index.tolist(): - # A hacky way to get the string representations of the column types of a parquet schema - # Reference: - # https://github.com/dask/fastparquet/blob/f4ecc67f50e7bf98b2d0099c9589c615ea4b06aa/fastparquet/schema.py - if _ts.parquet_thrift.Type._VALUES_TO_NAMES[schema_column_dtypes[idx]] == "BOOLEAN": - df[idx] = df[idx].astype("object") - df[idx].replace({0: False, 1: True, _pd.np.nan: None}, inplace=True) - - else: - df = _pd.read_parquet(chunk, columns=columns, engine=parquet_engine) - - return df - - @_exception_scopes.system_entry_point - def read(self, columns=None, concat=False, truncate_extra_columns=True, **kwargs): - """ - When this function is called, one chunk will be read and received as a Pandas data frame. Once all chunks - have been read, this function will return None. - - :param list[Text] columns: A list of columns to read. They must be a subset of the columns - defined for the Schema object. If specified, truncate_extra_columns must be True. - :param bool concat: If true, the entire object will be returned in one large data frame. - :param bool truncate_extra_columns: If true, only columns from the underlying parquet file will be read if - they are specified as columns in the schema object (except for empty schemas which will read all columns - regardless). If false, if there are additional columns in the underlying parquet file, they will also be - read. - :rtype: pandas.DataFrame - """ - if columns is not None and truncate_extra_columns is False: - raise _user_exceptions.FlyteAssertion( - "When reading a schema object, it is not possible to both specify a set of columns to read and " - "additionally not truncate_extra_columns. Either columns must not be specified or " - "truncate_extra_columns must be set to True (or not specified)." - ) - - self._access_guard() - - parquet_engine = _sdk_config.PARQUET_ENGINE.get() - if parquet_engine not in {"fastparquet", "pyarrow"}: - raise _user_exceptions.FlyteAssertion( - "environment variable parquet_engine must be one of 'pyarrow', 'fastparquet', or be unset" - ) - - df_out = None - if not columns: - columns = list(self._schema.type.sdk_columns.keys()) - - if len(columns) == 0 or truncate_extra_columns is False: - columns = None - - if concat: - frames = [ - # A hacky hack - # TODO: follow up the issue opened in the fastparquet repo for a more general fix - # issue URL: - _SchemaReader._read_parquet_with_type_promotion_override( - chunk=chunk, columns=columns, parquet_engine=parquet_engine - ) - # _pd.read_parquet(chunk, columns=columns, engine=parquet_engine) - for chunk in self._chunks[self._index :] - if _os.path.getsize(chunk) > 0 - ] - if len(frames) == 1: - df_out = frames[0] - elif len(frames) > 1: - df_out = _pd.concat(frames, copy=True) - self._index = len(self._chunks) - else: - while self._index < len(self._chunks) and df_out is None: - # Skip empty chunks so the user appears to have a continuous stream of data. 
- if _os.path.getsize(self._chunks[self._index]) > 0: - df_out = _SchemaReader._read_parquet_with_type_promotion_override( - chunk=self._chunks[self._index], columns=columns, parquet_engine=parquet_engine, **kwargs - ) - self._index += 1 - - if df_out is not None: - self._schema.compare_dataframe_to_schema(df_out, read=True, column_subset=columns) - - # Make sure the columns are renamed to exactly what the user specifies. This prevents unexpected - # unicode v. string mismatches. Also, if a schema is mapped with strict_names=False, the input might - # have totally different names. - user_columns = columns or _six.iterkeys(self._schema.type.sdk_columns) - # User-specified columns may or may not be unicode - # Since, in python 2, dictionary does a transparent translation between unicode and str for the key, - # (https://stackoverflow.com/a/24532329) - # we use this characteristic to create a trivial lookup dictionary, to make sure we can use either - # unicode or str to lookup, but get back whatever type the user actually used - user_column_dict = {c: c for c in user_columns} - if len(self._schema.type.columns) > 0: - # Avoid using pandas.DataFrame.rename() as this function incurs significant memory overhead - df_out.columns = [ - user_column_dict[col] if col in user_columns else col for col in df_out.columns.values - ] - return df_out - - -class _SchemaWriter(_SchemaIO): - def __init__(self, schema_instance, local_dir): - """ - :param Schema schema_instance: - :param flytekit.common.utils.Directory local_dir: - :param Text mode: - """ - super(_SchemaWriter, self).__init__(schema_instance, local_dir, "Write-Only") - - @_exception_scopes.system_entry_point - def close(self): - """ - Closes the writing IO context and uploads data to s3. - """ - try: - # TODO: Introduce system logging - # logging.info("Copying recursively {} -> {}".format(self._local_dir.name, self._schema.remote_prefix)) - _data_proxy.Data.put_data(self._local_dir.name, self._schema.remote_prefix, is_multipart=True) - finally: - super(_SchemaWriter, self).close() - - @_exception_scopes.system_entry_point - def write(self, data_frame, coerce_timestamps="us", allow_truncated_timestamps=False): - """ - Writes data frame as a chunk to the local directory owned by the Schema object. Will later be uploaded to s3. - - :param pandas.DataFrame data_frame: data frame to write as parquet - :param Text coerce_timestamps: format to store timestamp in parquet. 'us', 'ms', 's' are allowed values. - Note: if your timestamps will lose data due to the coercion, your write will fail! Nanoseconds are - problematic in the Parquet format and will not work. See allow_truncated_timestamps. - :param bool allow_truncated_timestamps: default False. Allow truncation when coercing timestamps to a coarser - resolution. - """ - self._access_guard() - if not isinstance(data_frame, _pd.DataFrame): - raise _user_exceptions.FlyteTypeException( - expected_type=_pd.DataFrame, - received_type=type(data_frame), - received_value=data_frame, - additional_msg="Only pandas DataFrame objects can be written to a Schema object", - ) - - self._schema.compare_dataframe_to_schema(data_frame) - all_columns = list(data_frame.columns.values) - - # Convert all columns to unicode as pyarrow's parquet reader can not handle mixed strings and unicode. - # Since columns from Hive are returned as unicode, if a user wants to add a column to a dataframe returned from - # Hive, then output the new data, the user would have to provide a unicode column name which is unnatural. 
-        unicode_columns = [_six.text_type(col) for col in all_columns]
-        data_frame.columns = unicode_columns
-        try:
-            filename = self._local_dir.get_named_tempfile(str(self._index).zfill(6))
-            data_frame.to_parquet(
-                filename,
-                coerce_timestamps=coerce_timestamps,
-                allow_truncated_timestamps=allow_truncated_timestamps,
-            )
-            if self._index == len(self._chunks):
-                self._chunks.append(filename)
-            self._index += 1
-        finally:
-            # Restore the original column names to prevent odd behavior for the user.
-            data_frame.columns = all_columns
-
-
-class _SchemaBackingMpBlob(_blob_impl.MultiPartBlob):
-    @property
-    def directory(self):
-        """
-        :rtype: flytekit.common.utils.Directory
-        """
-        return self._directory
-
-    def __enter__(self):
-        if not self.local_path:
-            if _data_proxy.LocalWorkingDirectoryContext.get() is None:
-                raise _user_exceptions.FlyteAssertion(
-                    "No temporary file system is present. Either call this method from within the "
-                    "context of a task or surround with a 'with LocalTestFileSystem():' block. Or "
-                    "specify a path when calling this function."
-                )
-            self._directory = _utils.AutoDeletingTempDir(
-                _uuid.uuid4().hex,
-                tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name,
-            )
-            self._is_managed = True
-            self._directory.__enter__()
-
-            if "r" in self.mode:
-                _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=True)
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if "w" in self.mode:
-            _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True)
-        return super(_SchemaBackingMpBlob, self).__exit__(exc_type, exc_val, exc_tb)
-
-
-class SchemaType(_type_models.SchemaType, metaclass=_sdk_bases.ExtendedSdkType):
-    _LITERAL_TYPE_TO_PROTO_ENUM = {
-        _primitives.Integer.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
-        _primitives.Float.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
-        _primitives.Boolean.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN,
-        _primitives.Datetime.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
-        _primitives.Timedelta.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION,
-        _primitives.String.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING,
-    }
-
-    def __init__(self, columns=None):
-        super(SchemaType, self).__init__(None)
-        self._set_columns(columns or [])
-
-    @property
-    def sdk_columns(self):
-        """
-        This is an ordered dictionary so iterating over it will be in the order columns were specified in the
-        constructor.
-        :rtype: dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType]
-        """
-        return self._sdk_columns
-
-    @property
-    def columns(self):
-        """
-        :rtype: list[flytekit.models.types.SchemaType.SchemaColumn]
-        """
-        return [
-            _type_models.SchemaType.SchemaColumn(n, type(self)._LITERAL_TYPE_TO_PROTO_ENUM[v.to_flyte_literal_type()])
-            for n, v in _six.iteritems(self.sdk_columns)
-        ]
-
-    @classmethod
-    def promote_from_model(cls, model):
-        """
-        :param flytekit.models.types.SchemaType model:
-        :rtype: SchemaType
-        """
-        _PROTO_ENUM_TO_SDK_TYPE = {
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER: _helpers.get_sdk_type_from_literal_type(
-                _primitives.Integer.to_flyte_literal_type()
-            ),
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT: _helpers.get_sdk_type_from_literal_type(
-                _primitives.Float.to_flyte_literal_type()
-            ),
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: _helpers.get_sdk_type_from_literal_type(
-                _primitives.Boolean.to_flyte_literal_type()
-            ),
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME: _helpers.get_sdk_type_from_literal_type(
-                _primitives.Datetime.to_flyte_literal_type()
-            ),
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION: _helpers.get_sdk_type_from_literal_type(
-                _primitives.Timedelta.to_flyte_literal_type()
-            ),
-            _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING: _helpers.get_sdk_type_from_literal_type(
-                _primitives.String.to_flyte_literal_type()
-            ),
-        }
-        return cls([(c.name, _PROTO_ENUM_TO_SDK_TYPE[c.type]) for c in model.columns])
-
-    def _set_columns(self, columns):
-        names_seen = set()
-        for column in columns:
-            if not isinstance(column, tuple):
-                raise _user_exceptions.FlyteValueException(
-                    column,
-                    "When specifying a Schema type with a known set of columns, each column must be "
-                    "specified as a tuple in the form ('name', type).",
-                )
-            if len(column) != 2:
-                raise _user_exceptions.FlyteValueException(
-                    column,
-                    "When specifying a Schema type with a known set of columns, each column must be "
-                    "specified as a tuple in the form ('name', type).",
-                )
-            name, sdk_type = column
-            sdk_type = _helpers.python_std_to_sdk_type(sdk_type)
-
-            if not isinstance(name, (str, _six.text_type)):
-                additional_msg = (
-                    "When specifying a Schema type with a known set of columns, the first element in"
-                    " each tuple must be text."
-                )
-                raise _user_exceptions.FlyteTypeException(
-                    received_type=type(name),
-                    received_value=name,
-                    expected_type={str, _six.text_type},
-                    additional_msg=additional_msg,
-                )
-
-            if (
-                not isinstance(sdk_type, _base_sdk_types.FlyteSdkType)
-                or sdk_type.to_flyte_literal_type() not in get_supported_literal_types_to_pandas_types()
-            ):
-                additional_msg = (
-                    "When specifying a Schema type with a known set of columns, the second element of "
-                    "each tuple must be a supported type. 
Failed for column: {name}".format(name=name) - ) - raise _user_exceptions.FlyteTypeException( - expected_type=list(get_supported_literal_types_to_pandas_types().keys()), - received_type=sdk_type, - additional_msg=additional_msg, - ) - - if name in names_seen: - raise ValueError( - "The column name {name} was specified multiple times when instantiating the " - "Schema.".format(name=name) - ) - names_seen.add(name) - - self._sdk_columns = _collections.OrderedDict(columns) - - -class Schema(_literal_models.Schema, metaclass=_sdk_bases.ExtendedSdkType): - def __init__(self, remote_path, mode="rb", schema_type=None): - """ - :param Text remote_path: - :param Text mode: - :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. If - not specified, the schema will be considered generic. - """ - self._mp_blob = _SchemaBackingMpBlob(remote_path, mode=mode) - super(Schema, self).__init__(self._mp_blob.uri, schema_type or SchemaType()) - self._io_object = None - - @classmethod - def promote_from_model(cls, model): - """ - :param flytekit.models.literals.Schema model: - :rtype: Schema - """ - return cls(model.uri, schema_type=SchemaType.promote_from_model(model.type)) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode="wb", schema_type=None): - """ - :param Text known_remote_location: The location to which to write the object. Usually an s3 path. - :param Text mode: - :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. If - not specified, the schema will be considered generic. - :rtype: Schema - """ - return cls(known_remote_location, mode=mode, schema_type=schema_type) - - @classmethod - @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode="wb", schema_type=None): - """ - :param Text mode: - :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. If - not specified, the schema will be considered generic. - :rtype: Schema - """ - return cls.create_at_known_location(_data_proxy.Data.get_remote_path(), mode=mode, schema_type=schema_type) - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", schema_type=None): - """ - :param Text remote_path: The location from which to fetch the object. Usually an s3 path. - :param Text local_path: [Optional] A local path to which to download the object. If specified, the object - will not be managed and might not be cleaned up by the system upon exiting the context. - :param bool overwrite: If True, objects will be overwritten at the provided local_path in order to fetch this - object. Default is False. - :param Text mode: Read or write mode of the object. - :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. If - not specified, the schema will be considered generic. 
-        :rtype: Schema
-        """
-        schema = cls(remote_path, mode=mode, schema_type=schema_type)
-        schema.download(local_path=local_path, overwrite=overwrite)
-        return schema
-
-    @classmethod
-    @_exception_scopes.system_entry_point
-    def from_python_std(cls, t_value, schema_type=None):
-        """
-        :param T t_value:
-        :param SchemaType schema_type: [Optional] If specified, we will ensure the resulting schema conforms to this
-            type.
-        :rtype: Schema
-        """
-        if isinstance(t_value, (str, _six.text_type)):
-            if _os.path.isdir(t_value):
-                schema = cls.create_at_any_location(schema_type=schema_type)
-                schema.multipart_blob._directory = _utils.Directory(t_value)
-                schema.upload()
-            else:
-                schema = cls.create_at_known_location(t_value, schema_type=schema_type)
-            return schema
-        elif isinstance(t_value, cls):
-            return t_value
-        elif isinstance(t_value, _pd.DataFrame):
-            # Accepts a pandas dataframe and converts to a Schema object
-            o = cls.create_at_any_location(schema_type=schema_type)
-            with o as w:
-                w.write(t_value)
-            return o
-        elif isinstance(t_value, list):
-            # Accepts a list of pandas dataframes and converts to a Schema object
-            o = cls.create_at_any_location(schema_type=schema_type)
-            with o as w:
-                for x in t_value:
-                    if isinstance(x, _pd.DataFrame):
-                        w.write(x)
-                    else:
-                        raise _user_exceptions.FlyteTypeException(
-                            type(t_value),
-                            {str, _six.text_type, Schema},
-                            received_value=x,
-                            additional_msg="A Schema object can only be created from a pandas DataFrame or a list of pandas DataFrames.",
-                        )
-            return o
-        else:
-            raise _user_exceptions.FlyteTypeException(
-                type(t_value),
-                {str, _six.text_type, Schema},
-                received_value=t_value,
-                additional_msg="Unable to create Schema from user-provided value.",
-            )
-
-    @classmethod
-    def from_string(cls, string_value, schema_type=None):
-        """
-        :param Text string_value:
-        :param SchemaType schema_type:
-        :rtype: Schema
-        """
-        if not string_value:
-            raise _user_exceptions.FlyteValueException(string_value, "Cannot create a Schema from an empty path")
-        return cls.create_at_known_location(string_value, schema_type=schema_type)
-
-    @classmethod
-    @_exception_scopes.system_entry_point
-    def create_from_hive_query(
-        cls,
-        select_query,
-        stage_query=None,
-        schema_to_table_name_map=None,
-        schema_type=None,
-        known_location=None,
-    ):
-        """
-        Returns a query that can be submitted to Hive to produce the desired output. It also returns a properly-typed
-        schema object.
-
-        :param Text select_query: Query for selecting data from Hive
-        :param Text stage_query: Query for building temporary tables on Hive.
-            Runs before the select query. Temporary tables are supported but CTEs are not supported.
-        :param Dict[Text, Text] schema_to_table_name_map: A map of column names in the schema to the column names
-            returned from the select query
-        :param Text known_location: Create the schema object at a known s3 location.
-        :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. If
-            not specified, the schema will be considered generic.
- :return: Schema, Text - """ - schema_object = cls( - known_location or _data_proxy.Data.get_remote_directory(), - mode="wb", - schema_type=schema_type, - ) - - if len(schema_object.type.sdk_columns) > 0: - identity_dict = {n: n for n in _six.iterkeys(schema_object.type.sdk_columns)} - identity_dict.update(schema_to_table_name_map or {}) - schema_to_table_name_map = identity_dict - - columnar_clauses = [] - for name, sdk_type in _six.iteritems(schema_object.type.sdk_columns): - if sdk_type == _primitives.Float: - columnar_clauses.append( - "CAST({table_column_name} as double) {schema_name}".format( - table_column_name=schema_to_table_name_map[name], - schema_name=name, - ) - ) - else: - columnar_clauses.append( - "{table_column_name} as {schema_name}".format( - table_column_name=schema_to_table_name_map[name], - schema_name=name, - ) - ) - columnar_query = ",\n\t\t".join(columnar_clauses) - else: - columnar_query = "*" - - stage_query_str = _six.text_type(stage_query or "") - # the stage query should always end with a semicolon - stage_query_str = stage_query_str if stage_query_str.endswith(";") else (stage_query_str + ";") - query = _HIVE_QUERY_FORMATTER.format( - url=schema_object.remote_location, - stage_query_str=stage_query_str, - query_str=select_query.strip().strip(";"), - columnar_query=columnar_query, - table=_uuid.uuid4().hex, - ) - return schema_object, query - - @property - def local_path(self): - """ - Local filesystem path where the file was downloaded - :rtype: Text - """ - return self._mp_blob.local_path - - @property - def remote_location(self): - """ - Path to where this MultiPartBlob will be synced. This is needed for reverse compatibility. - :rtype: Text - """ - return self.uri - - @property - def remote_prefix(self): - """ - Path to where this MultiPartBlob will be synced. This is needed for reverse compatibility. - :rtype: Text - """ - return self.uri - - @property - def uri(self): - """ - Path to where this MultiPartBlob will be synced. - :rtype: Text - """ - return self.multipart_blob.uri - - @property - def mode(self): - """ - The mode string the MultiPartBlob is associated with. - :rtype: Text - """ - return self._mp_blob.mode - - @property - def type(self): - """ - The schema type definition associated with this object. - :rtype: SchemaType - """ - return self._type - - @property - def multipart_blob(self): - """ - :rtype: flytekit.common.types.impl.blobs.MultiPartBlob - """ - return self._mp_blob - - @_exception_scopes.system_entry_point - def __enter__(self): - """ - :rtype: _SchemaIO - """ - if self._io_object is not None: - raise _user_exceptions.FlyteAssertion( - "The context of a schema can only be entered once at a time. Make sure the previous " - "'with' block has been exited." 
- ) - - self._mp_blob.__enter__() - if "r" in self.mode: - self._io_object = _SchemaReader(self, self.multipart_blob.directory) - else: - self._io_object = _SchemaWriter(self, self.multipart_blob.directory) - return self._io_object - - @_exception_scopes.system_entry_point - def __exit__(self, exc_type, exc_val, exc_tb): - self._io_object = None - return self._mp_blob.__exit__(exc_type, exc_val, exc_tb) - - def __repr__(self): - return "Schema({columns}) @ {location} ({mode})".format( - columns=self.type.columns, - location=self.remote_prefix, - mode="read-only" if "r" in self.mode else "write-only", - ) - - @_exception_scopes.system_entry_point - def download(self, local_path=None, overwrite=False): - """ - :param Text local_path: [Optional] A local path to which to download the object. If specified, the object - will not be managed and might not be cleaned up by the system upon exiting the context. - :param bool overwrite: If True, objects will be overwritten at the provided local_path in order to fetch this - object. Default is False. - :rtype: Schema - """ - self.multipart_blob.download(local_path=local_path, overwrite=overwrite) - - @_exception_scopes.system_entry_point - def get_write_partition_to_hive_table_query( - self, - table_name, - partitions=None, - schema_to_table_name_map=None, - partitions_in_table=False, - append_to_partition=False, - ): - """ - Returns a Hive query string that will update the metatable to point to the data as the new partition. - - :param Text table_name: - :param dict[Text, T] partitions: A dictionary mapping table partition key names to the values matching this - partition. - :param dict[Text, Text] schema_to_table_name_map: Mapping of names in current schema to table in which it is - being inserted. Currently not supported. Must be None. - :param bool partitions_in_table: Whether or not the partition columns exist in the data being submitted. - Currently not supported. Must be false - :param bool append_to_partition: Whether or not to append new values to a partition. Currently not supported. - :return: Text - """ - partition_string = "" - where_string = "" - identity_dict = {n: n for n in _six.iterkeys(self.type.sdk_columns)} - identity_dict.update(schema_to_table_name_map or {}) - schema_to_table_name_map = identity_dict - table_to_schema_name_map = {v: k for k, v in _six.iteritems(schema_to_table_name_map)} - - if partitions: - partition_conditions = [] - for partition_name, partition_value in _six.iteritems(partitions): - if not isinstance(partition_name, (str, _six.text_type)): - raise _user_exceptions.FlyteTypeException( - expected_type={str, _six.text_type}, - received_type=type(partition_name), - received_value=partition_name, - additional_msg="All partition names must be type str.", - ) - if type(partition_value) not in _ALLOWED_PARTITION_TYPES: - raise _user_exceptions.FlyteTypeException( - expected_type=_ALLOWED_PARTITION_TYPES, - received_type=type(partition_value), - received_value=partition_value, - additional_msg="Partition {name} has an unsupported type.".format(name=partition_name), - ) - - # We need the string to be quoted in the query, so let's take repr of it. 
-                if isinstance(partition_value, (str, _six.text_type)):
-                    partition_value = repr(partition_value)
-                partition_conditions.append(
-                    "{partition_name} = {partition_value}".format(
-                        partition_name=partition_name, partition_value=partition_value
-                    )
-                )
-            partition_formatter = "PARTITION (\n\t{conditions}\n)"
-            partition_string = partition_formatter.format(conditions=",\n\t".join(partition_conditions))
-
-        if partitions_in_table and partitions:
-            where_clauses = []
-            for partition_name, partition_value in partitions:
-                where_clauses.append(
-                    "\n\t\t{schema_name} = {value_str} AND ".format(
-                        schema_name=table_to_schema_name_map[partition_name],
-                        value_str=partition_value,
-                    )
-                )
-            where_string = "WHERE\n\t\t{where_clauses}".format(where_clauses=" AND\n\t\t".join(where_clauses))
-
-        if where_string or partitions_in_table:
-            raise _user_exceptions.FlyteAssertion(
-                "Currently, the partition values should not be present in the schema pushed to Hive."
-            )
-        if append_to_partition:
-            raise _user_exceptions.FlyteAssertion(
-                "Currently, partitions can only be overwritten, they cannot be appended."
-            )
-        if not partitions:
-            raise _user_exceptions.FlyteAssertion(
-                "Currently, partition values MUST be specified for writing to a table."
-            )
-
-        return _format_insert_partition_query(
-            remote_location=self.remote_location,
-            table_name=table_name,
-            partition_string=partition_string,
-        )
-
-    def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False):
-        """
-        Do necessary type checking of a pandas data frame. Raise an exception if it doesn't match.
-        :param pandas.DataFrame data_frame: data frame to type check
-        :param list[Text] column_subset:
-        :param bool read: Used to alter the error message for more clarity.
-        """
-        all_columns = list(data_frame.columns.values)
-        schema_column_names = list(self.type.sdk_columns.keys())
-
-        # Skip checking if we have a generic schema type (no specified columns)
-        if not schema_column_names:
-            return
-
-        # If we specify a subset of columns, ensure they all exist and then only take those columns
-        if column_subset is not None:
-            schema_column_names = []
-            failed_columns = []
-            for column in column_subset:
-                if column not in self.type.sdk_columns:
-                    failed_columns.append(column)
-                else:
-                    schema_column_names.append(column)
-
-            if len(failed_columns) > 0:
-                additional_msg = ""
-                raise _user_exceptions.FlyteAssertion(
-                    "{} were requested but could not be found in the schema: {}.{}".format(
-                        failed_columns, self.type.sdk_columns, additional_msg
-                    )
-                )
-
-        if not all(c in all_columns for c in schema_column_names):
-            raise _user_exceptions.FlyteTypeException(
-                expected_type=self.type.sdk_columns,
-                received_type=data_frame.columns,
-                additional_msg="Mismatch between the data frame's column names {} and schema's column names {} "
-                "with strict_names=True.".format(all_columns, schema_column_names),
-            )
-
-        # This only iterates if the Schema has specified columns.
-        for name in schema_column_names:
-            literal_type = self.type.sdk_columns[name].to_flyte_literal_type()
-            dtype = data_frame[name].dtype
-
-            # TODO np.issubdtype is deprecated. Replace it
-            if all(
-                not _np.issubdtype(dtype, allowed_type)
-                for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type]
-            ):
-                if read:
-                    read_or_write_msg = "read data frame object from schema"
-                else:
-                    read_or_write_msg = "write data frame object to schema"
-                additional_msg = (
-                    "Cannot {read_write} because the types do not match. Column "
-                    "'{name}' did not pass type checking. 
Note: If your " - "column contains null values, the types might not transition as expected between parquet and " - "pandas. For more information, see: " - "http://arrow.apache.org/docs/python/pandas.html#arrow-pandas-conversion".format( - read_write=read_or_write_msg, name=name - ) - ) - raise _user_exceptions.FlyteTypeException( - expected_type=get_supported_literal_types_to_pandas_types()[literal_type], - received_type=dtype, - additional_msg=additional_msg, - ) - - def cast_to(self, other_type): - """ - :param SchemaType other_type: - :rtype: Schema - """ - if len(other_type.sdk_columns) > 0: - for k, v in _six.iteritems(other_type.sdk_columns): - if k not in self.type.sdk_columns: - raise _user_exceptions.FlyteTypeException( - self.type, - other_type, - additional_msg="Cannot cast because a required column '{}' was not found.".format(k), - received_value=self, - ) - if ( - not isinstance(v, _base_sdk_types.FlyteSdkType) - or v.to_flyte_literal_type() != self.type.sdk_columns[k].to_flyte_literal_type() - ): - raise _user_exceptions.FlyteTypeException( - self.type.sdk_columns[k], - v, - additional_msg="Cannot cast because the column type for column '{}' does not match.".format(k), - ) - return Schema(self.remote_location, mode=self.mode, schema_type=other_type) - - @_exception_scopes.system_entry_point - def upload(self): - """ - Upload the schema to the remote location - """ - if "w" not in self.mode: - raise _user_exceptions.FlyteAssertion("Cannot upload a read-only schema!") - - elif not self.local_path: - raise _user_exceptions.FlyteAssertion( - "The schema is not currently backed by a local directory " - "and therefore cannot be uploaded. Please write to this before " - "attempting an upload." - ) - else: - # TODO: Introduce system logging - # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location)) - _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True) diff --git a/flytekit/common/types/primitives.py b/flytekit/common/types/primitives.py deleted file mode 100644 index 446140dc2c..0000000000 --- a/flytekit/common/types/primitives.py +++ /dev/null @@ -1,595 +0,0 @@ -import datetime as _datetime -import json as _json -import typing - -import six as _six -from dateutil import parser as _parser -from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct -from pytimeparse import parse as _parse_duration_string - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types - - -class Integer(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Integer - """ - try: - return cls(int(string_value)) - except (ValueError, TypeError): - raise _user_exceptions.FlyteTypeException( - _six.text_type, - int, - additional_msg="String not castable to Integer SDK type:" " {}".format(string_value), - ) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. 
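# The column-by-column check in compare_dataframe_to_schema above reduces
# to a numpy sub-dtype test. Minimal sketch; allowed_for_integer is an
# assumed stand-in for one entry of get_supported_literal_types_to_pandas_types().
import numpy as np
import pandas as pd

allowed_for_integer = (np.integer,)
df = pd.DataFrame({"a": [1, 2, 3]})
assert any(np.issubdtype(df["a"].dtype, t) for t in allowed_for_integer)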
- :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - if type(t_value) not in _six.integer_types: - raise _user_exceptions.FlyteTypeException(type(t_value), _six.integer_types, t_value) - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.INTEGER) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: Integer - """ - return cls(literal_model.scalar.primitive.integer) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Integer" - - def __init__(self, value): - """ - :param int value: Int value to wrap - """ - super(Integer, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=value))) - - def to_python_std(self): - """ - :rtype: int - """ - return self.scalar.primitive.integer - - def short_string(self): - """ - :rtype: Text - """ - return "Integer({})".format(self.scalar.primitive.integer) - - -class Float(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Float - """ - try: - return cls(float(string_value)) - except ValueError: - raise _user_exceptions.FlyteTypeException( - _six.text_type, - float, - additional_msg="String not castable to Float SDK type:" " {}".format(string_value), - ) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - if type(t_value) != float: - raise _user_exceptions.FlyteTypeException(type(t_value), float, t_value) - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.FLOAT) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. 
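# Under the hood these wrappers are thin shells over flytekit.models; a
# round-trip sketch of the scalar packing Integer performs, using the same
# model constructors imported at the top of this module:
from flytekit.models import literals as _literals
from flytekit.models import types as _idl_types

lit = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=42)))
assert lit.scalar.primitive.integer == 42
int_type = _idl_types.LiteralType(simple=_idl_types.SimpleType.INTEGER)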
- :param flytekit.models.literals.Literal literal_model: - :rtype: Float - """ - return cls(literal_model.scalar.primitive.float_value) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Float" - - def __init__(self, value): - """ - :param float value: value to wrap - """ - super(Float, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(float_value=value))) - - def to_python_std(self): - """ - :rtype: float - """ - return self.scalar.primitive.float_value - - def short_string(self): - """ - :rtype: Text - """ - return "Float({})".format(self.scalar.primitive.float_value) - - -class Boolean(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Boolean - """ - if string_value == "1" or string_value.lower() == "true": - return cls(True) - elif string_value == "0" or string_value.lower() == "false": - return cls(False) - raise _user_exceptions.FlyteTypeException( - _six.text_type, - bool, - additional_msg="String not castable to Boolean SDK " "type: {}".format(string_value), - ) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - if type(t_value) != bool: - raise _user_exceptions.FlyteTypeException(type(t_value), bool, t_value) - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.BOOLEAN) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: bool - """ - return cls(literal_model.scalar.primitive.boolean) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Boolean" - - def __init__(self, value): - """ - :param bool value: value to wrap - """ - super(Boolean, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(boolean=value))) - - def to_python_std(self): - """ - :rtype: bool - """ - return self.scalar.primitive.boolean - - def short_string(self): - """ - :rtype: Text - """ - return "Boolean({})".format(self.scalar.primitive.boolean) - - -class String(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: String - """ - if type(string_value) == dict or type(string_value) == list: - raise _user_exceptions.FlyteTypeException( - type(string_value), - _six.text_type, - additional_msg="Should not cast native Python type to string {}".format(string_value), - ) - return cls(string_value) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - Creates an object of this type from the model primitive defining it. - :param T t_value: It is up to each individual object as to whether or not this value can be cast. 
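# The string-to-boolean rule above, as a standalone sketch (a plain
# ValueError stands in for the SDK's FlyteTypeException):
def parse_flyte_bool(string_value):
    if string_value == "1" or string_value.lower() == "true":
        return True
    if string_value == "0" or string_value.lower() == "false":
        return False
    raise ValueError("String not castable to Boolean: {}".format(string_value))

assert parse_flyte_bool("TRUE") is True and parse_flyte_bool("0") is False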
- :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - if type(t_value) not in set([str, _six.text_type]): - raise _user_exceptions.FlyteTypeException(type(t_value), set([str, _six.text_type]), t_value) - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRING) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: String - """ - return cls(literal_model.scalar.primitive.string_value) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "String" - - def __init__(self, value): - """ - :param Text value: value to wrap - """ - super(String, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(string_value=value))) - - def to_python_std(self): - """ - :rtype: Text - """ - return self.scalar.primitive.string_value - - def short_string(self): - """ - :rtype: Text - """ - _TRUNCATE_LENGTH = 100 - return "String('{}'{})".format( - self.scalar.primitive.string_value[:_TRUNCATE_LENGTH], - " ..." if len(self.scalar.primitive.string_value) > _TRUNCATE_LENGTH else "", - ) - - def verbose_string(self): - """ - :rtype: Text - """ - return "String('{}')".format(self.scalar.primitive.string_value) - - -class Datetime(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Datetime - """ - try: - python_std_datetime = _parser.parse(string_value) - except ValueError: - raise _user_exceptions.FlyteTypeException( - _six.text_type, - _datetime.datetime, - additional_msg="String not castable to Datetime " "SDK type: {}".format(string_value), - ) - - return cls.from_python_std(python_std_datetime) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif type(t_value) != _datetime.datetime: - raise _user_exceptions.FlyteTypeException(type(t_value), _datetime.datetime, t_value) - elif t_value.tzinfo is None: - raise _user_exceptions.FlyteValueException( - t_value, - "Datetime objects in Flyte must be timezone aware. " "tzinfo was found to be None.", - ) - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.DATETIME) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. 
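# The tzinfo check above is why naive datetimes are rejected; attaching an
# explicit timezone satisfies it. Stdlib-only sketch:
import datetime

naive = datetime.datetime(2021, 1, 1)                                # rejected above
aware = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc)  # accepted
assert naive.tzinfo is None and aware.tzinfo is not None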
- :param flytekit.models.literals.Literal literal_model: - :rtype: Datetime - """ - return cls(literal_model.scalar.primitive.datetime) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Datetime" - - def __init__(self, value): - """ - :param datetime.datetime value: value to wrap - """ - super(Datetime, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(datetime=value))) - - def to_python_std(self): - """ - :rtype: datetime.datetime - """ - return self.scalar.primitive.datetime - - def short_string(self): - """ - :rtype: Text - """ - return "Datetime({})".format(_six.text_type(self.scalar.primitive.datetime)) - - -class Timedelta(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: Uses https://github.com/wroberts/pytimeparse for parsing - :rtype: Timedelta - """ - td = _parse_duration_string(string_value) - if td is None: - raise _user_exceptions.FlyteTypeException( - _six.text_type, - _datetime.timedelta, - additional_msg="Could not convert string to" " time delta: {}".format(string_value), - ) - return cls.from_python_std(_datetime.timedelta(seconds=td)) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif type(t_value) != _datetime.timedelta: - raise _user_exceptions.FlyteTypeException(type(t_value), _datetime.timedelta, t_value) - - return cls(t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.DURATION) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: Timedelta - """ - return cls(literal_model.scalar.primitive.duration) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Timedelta" - - def __init__(self, value): - """ - :param datetime.timedelta value: value to wrap - """ - super(Timedelta, self).__init__(scalar=_literals.Scalar(primitive=_literals.Primitive(duration=value))) - - def to_python_std(self): - """ - :rtype: datetime.timedelta - """ - return self.scalar.primitive.duration - - def short_string(self): - """ - :rtype: Text - """ - return "Timedelta({})".format(self.scalar.primitive.duration) - - -class Generic(_base_sdk_types.FlyteSdkValue): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: Should be a JSON formatted string - :rtype: Generic - """ - try: - t = _json_format.Parse(string_value, _struct.Struct()) - except Exception: - raise _user_exceptions.FlyteValueException(string_value, "Could not be parsed from JSON.") - return cls(t) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. 
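# Duration strings go through pytimeparse (imported above), which returns
# a number of seconds or None on failure, and the result is wrapped in a
# timedelta:
import datetime
from pytimeparse import parse as parse_duration_string

seconds = parse_duration_string("1h30m")
assert seconds == 5400
assert datetime.timedelta(seconds=seconds) == datetime.timedelta(hours=1, minutes=30)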
- :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif not isinstance(t_value, dict): - raise _user_exceptions.FlyteTypeException(type(t_value), dict, t_value) - - try: - t = _json.dumps(t_value) - except Exception: - raise _user_exceptions.FlyteValueException(t_value, "Is not JSON serializable.") - - return cls(_json_format.Parse(t, _struct.Struct())) - - @classmethod - def to_flyte_literal_type(cls, metadata: typing.Dict = None): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata=metadata) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: Generic - """ - return cls(literal_model.scalar.generic) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Generic" - - def __init__(self, value): - """ - :param _struct.Struct value: value to wrap - """ - super(Generic, self).__init__(scalar=_literals.Scalar(generic=value)) - - def to_python_std(self): - """ - :rtype: dict[Text, T] - """ - return _json.loads(_json_format.MessageToJson(self.scalar.generic)) - - def short_string(self): - """ - :rtype: Text - """ - return "Generic({})".format(self.to_python_std()) - - def long_string(self): - """ - :rtype: Text - """ - return "Generic({})".format(self.to_python_std()) diff --git a/flytekit/common/types/proto.py b/flytekit/common/types/proto.py deleted file mode 100644 index 2c3423ccfb..0000000000 --- a/flytekit/common/types/proto.py +++ /dev/null @@ -1,319 +0,0 @@ -import base64 as _base64 -from typing import Type, Union - -import six as _six -from google.protobuf import reflection as _proto_reflection -from google.protobuf.json_format import Error -from google.protobuf.json_format import MessageToDict as _MessageToDict -from google.protobuf.json_format import ParseDict as _ParseDict -from google.protobuf.reflection import GeneratedProtocolMessageType -from google.protobuf.struct_pb2 import Struct - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types -from flytekit.models.common import FlyteIdlEntity, FlyteType -from flytekit.models.types import LiteralType - -ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType] - - -class ProtobufType(_base_sdk_types.FlyteSdkType): - _pb_type = Struct - - @property - def pb_type(cls) -> GeneratedProtocolMessageType: - """ - :rtype: GeneratedProtocolMessageType - """ - return cls._pb_type - - @property - def descriptor(cls): - """ - :rtype: Text - """ - return "{}.{}".format(cls.pb_type.__module__, cls.pb_type.__name__) - - @property - def tag(cls): - """ - :rtype: Text - """ - return "{}{}".format(Protobuf.TAG_PREFIX, cls.descriptor) - - -class Protobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType): - PB_FIELD_KEY = "pb_type" - TAG_PREFIX = "{}=".format(PB_FIELD_KEY) - - def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]): - """ - :param Union[T, FlyteIdlEntity] pb_object: - """ - v = pb_object - # This section converts an existing proto object (or a subclass of) to the right type expected by this instance - # of GenericProto. 
GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it - # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final - # struct. - # If the provided object has to_flyte_idl(), call it to produce a raw proto. - if isinstance(pb_object, FlyteIdlEntity): - v = pb_object.to_flyte_idl() - - # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to - # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class - # is initialized with one. - expected_type = type(self).pb_type - if expected_type != type(v) and expected_type != type(pb_object): - if isinstance(type(self).pb_type, FlyteType): - v = expected_type.from_flyte_idl(v).to_flyte_idl() - else: - raise _user_exceptions.FlyteTypeException( - received_type=type(pb_object), expected_type=expected_type, received_value=pb_object - ) - data = v.SerializeToString() - super(Protobuf, self).__init__( - scalar=_literals.Scalar( - binary=_literals.Binary(value=bytes(data) if _six.PY2 else data, tag=type(self).tag) - ) - ) - - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: b64 encoded string of bytes - :rtype: Protobuf - """ - try: - decoded = _base64.b64decode(string_value) - except TypeError: - raise _user_exceptions.FlyteValueException(string_value, "The string is not valid base64-encoded.") - pb_obj = cls.pb_type() - pb_obj.ParseFromString(decoded) - return cls(pb_obj) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: _base_sdk_types.FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity): - return cls(t_value) - else: - raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.BINARY, metadata={cls.PB_FIELD_KEY: cls.descriptor}) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. - :param flytekit.models.literals.Literal literal_model: - :rtype: Protobuf - """ - if literal_model.scalar.binary.tag != cls.tag: - raise _user_exceptions.FlyteTypeException( - literal_model.scalar.binary.tag, - cls.pb_type, - received_value=_base64.b64encode(literal_model.scalar.binary.value), - additional_msg="Can not deserialize as proto tags don't match.", - ) - pb_obj = cls.pb_type() - pb_obj.ParseFromString(literal_model.scalar.binary.value) - return cls(pb_obj) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return "Types.Proto({})".format(cls.descriptor) - - def to_python_std(self): - """ - :returns: The protobuf object as defined by the user. 
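# The Protobuf wrapper above stores SerializeToString() bytes plus a
# "pb_type=<module>.<name>" tag so promote_from_model can refuse mismatched
# payloads. Round-trip sketch; FileDescriptorProto is a stand-in message:
from google.protobuf.descriptor_pb2 import FileDescriptorProto

msg = FileDescriptorProto(name="example.proto")
tag = "pb_type={}.{}".format(type(msg).__module__, type(msg).__name__)
data = msg.SerializeToString()

decoded = FileDescriptorProto()
decoded.ParseFromString(data)
assert decoded.name == "example.proto"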
- :rtype: T - """ - pb_obj = type(self).pb_type() - pb_obj.ParseFromString(self.scalar.binary.value) - return pb_obj - - def short_string(self): - """ - :rtype: Text - """ - return "{}".format(self.to_python_std()) - - -def create_protobuf(pb_type: Type[GeneratedProtocolMessageType]) -> Type[Protobuf]: - """ - :param Type[GeneratedProtocolMessageType] pb_type: - :rtype: Type[Protobuf] - """ - if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType): - raise _user_exceptions.FlyteTypeException( - expected_type=_proto_reflection.GeneratedProtocolMessageType, - received_type=type(pb_type), - received_value=pb_type, - ) - - class _Protobuf(Protobuf): - _pb_type = pb_type - - return _Protobuf - - -class GenericProtobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType): - PB_FIELD_KEY = "pb_type" - TAG_PREFIX = "{}=".format(PB_FIELD_KEY) - - def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]): - """ - :param Union[T, FlyteIdlEntity] pb_object: - """ - struct = Struct() - v = pb_object - - # This section converts an existing proto object (or a subclass of) to the right type expected by this instance - # of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it - # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final - # struct. - # If the provided object has to_flyte_idl(), call it to produce a raw proto. - if isinstance(pb_object, FlyteIdlEntity): - v = pb_object.to_flyte_idl() - - # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to - # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class - # is initialized with one. - expected_type = type(self).pb_type - if expected_type != type(v) and expected_type != type(pb_object): - if isinstance(type(self).pb_type, FlyteType): - v = expected_type.from_flyte_idl(v).to_flyte_idl() - else: - raise _user_exceptions.FlyteTypeException( - received_type=type(pb_object), expected_type=expected_type, received_value=pb_object - ) - - struct.update(_MessageToDict(v)) - super().__init__(scalar=_literals.Scalar(generic=struct)) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type - - @classmethod - def from_python_std(cls, t_value: Union[GeneratedProtocolMessageType, FlyteIdlEntity]): - """ - :param Union[T, FlyteIdlEntity] t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: _base_sdk_types.FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity): - return cls(t_value) - else: - raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value) - - @classmethod - def to_flyte_literal_type(cls) -> LiteralType: - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor}) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. 
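# GenericProtobuf above goes through JSON-shaped dicts instead of raw
# bytes: MessageToDict folds the message into a protobuf Struct, and
# ParseDict recovers it. Sketch with the same stand-in message:
from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct

msg = FileDescriptorProto(name="example.proto", package="demo")
struct = Struct()
struct.update(MessageToDict(msg))  # pack: message -> dict -> Struct

restored = ParseDict(MessageToDict(struct), FileDescriptorProto())  # unpack
assert restored.package == "demo"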
- :param flytekit.models.literals.Literal literal_model: - :rtype: Protobuf - """ - pb_obj = cls.pb_type() - try: - dictionary = _MessageToDict(literal_model.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) - except Error as err: - raise _user_exceptions.FlyteTypeException( - received_type="generic", - expected_type=cls.pb_type, - received_value=_base64.b64encode(literal_model.scalar.generic), - additional_msg=f"Can not deserialize. Error: {err.__str__()}", - ) - - return cls(pb_obj) - - @classmethod - def short_class_string(cls) -> str: - """ - :rtype: Text - """ - return "Types.GenericProto({})".format(cls.descriptor) - - def to_python_std(self): - """ - :returns: The protobuf object as defined by the user. - :rtype: T - """ - pb_obj = type(self).pb_type() - try: - dictionary = _MessageToDict(self.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) - except Error as err: - raise _user_exceptions.FlyteTypeException( - received_type="generic", - expected_type=type(self).pb_type, - received_value=_base64.b64encode(self.scalar.generic), - additional_msg=f"Can not deserialize. Error: {err.__str__()}", - ) - return pb_obj - - def short_string(self) -> str: - """ - :rtype: Text - """ - return "{}".format(self.to_python_std()) - - -def create_generic(pb_type: Type[GeneratedProtocolMessageType]) -> Type[GenericProtobuf]: - """ - Creates a generic protobuf type that represents protobuf type ProtobufT and that will get serialized into a struct. - - :param Type[GeneratedProtocolMessageType] pb_type: - :rtype: Type[GenericProtobuf] - """ - if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType) and not issubclass( - pb_type, FlyteIdlEntity - ): - raise _user_exceptions.FlyteTypeException( - expected_type=_proto_reflection.GeneratedProtocolMessageType, - received_type=type(pb_type), - received_value=pb_type, - ) - - class _Protobuf(GenericProtobuf): - _pb_type = pb_type - - return _Protobuf diff --git a/flytekit/common/types/schema.py b/flytekit/common/types/schema.py deleted file mode 100644 index eaf38d1c88..0000000000 --- a/flytekit/common/types/schema.py +++ /dev/null @@ -1,189 +0,0 @@ -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types - - -class SchemaInstantiator(_base_sdk_types.InstantiableType): - def create_at_known_location(cls, location): - """ - :param Text location: - :rtype: flytekit.common.types.impl.schema.Schema - """ - return _schema_impl.Schema.create_at_known_location(location, mode="wb", schema_type=cls.schema_type) - - def fetch(cls, remote_path, local_path=None): - """ - :param Text remote_path: - :param Text local_path: [Optional] If specified, the Schema is copied to this location. If specified, - this location is NOT managed and the schema will not be cleaned up upon exit. 
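# create_protobuf and create_generic in proto.py above are class factories:
# each call mints a wrapper subclass closed over one message type. The
# idiom, stripped of the SDK machinery:
def create_wrapper(pb_type):
    class _Wrapper:
        _pb_type = pb_type

        def __init__(self, msg):
            self.data = msg.SerializeToString()

    _Wrapper.__name__ = "Wrapper_{}".format(pb_type.__name__)
    return _Wrapper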
- :rtype: flytekit.common.types.impl.schema.Schema - """ - return _schema_impl.Schema.fetch(remote_path, mode="rb", local_path=local_path, schema_type=cls.schema_type) - - def create(cls): - """ - :rtype: flytekit.common.types.impl.schema.Schema - """ - return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) - - def create_from_hive_query( - cls, - select_query, - stage_query=None, - schema_to_table_name_map=None, - known_location=None, - ): - """ - Returns a query that can be submitted to Hive and produce the desired output. It also returns a properly-typed - schema object. - - :param Text select_query: Query for selecting data from Hive - :param Text stage_query: Query for building temporary tables on Hive. - Runs before the select query. Temporary tables are supported but CTEs are not supported. - :param Dict[Text, Text] schema_to_table_name_map: A map of column names in the schema to the column names - returned from the select query - :param Text known_location: create the schema object at a known s3 location. - :rtype: flytekit.common.types.impl.schema.Schema, Text - """ - return _schema_impl.Schema.create_from_hive_query( - select_query=select_query, - stage_query=stage_query, - schema_to_table_name_map=schema_to_table_name_map, - known_location=known_location, - schema_type=cls.schema_type, - ) - - def __call__(cls, *args, **kwargs): - """ - TODO: Is there a better way to deal with this? - - :rtype: flytekit.common.types.impl.schema.Schema - """ - if not args and not kwargs: - return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) - else: - return super(SchemaInstantiator, cls).__call__(*args, **kwargs) - - @property - def schema_type(cls): - """ - :rtype: _schema_impl.SchemaType - """ - return cls._schema_type - - @property - def columns(cls): - """ - :rtype: dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType] - """ - return cls.schema_type.sdk_columns - - -class Schema(_base_sdk_types.FlyteSdkValue, metaclass=SchemaInstantiator): - @classmethod - def from_string(cls, string_value): - """ - :param Text string_value: - :rtype: Schema - """ - if not string_value: - _user_exceptions.FlyteValueException(string_value, "Cannot create a Schema from an empty path") - return cls(_schema_impl.Schema.from_string(string_value, schema_type=cls.schema_type)) - - @classmethod - def is_castable_from(cls, other): - """ - :param flytekit.common.types.base_literal_types.FlyteSdkType other: - :rtype: bool - """ - return cls == other - - @classmethod - def from_python_std(cls, t_value): - """ - :param T t_value: It is up to each individual object as to whether or not this value can be cast. - :rtype: FlyteSdkValue - :raises: flytekit.common.exceptions.user.FlyteTypeException - """ - if t_value is None: - return _base_sdk_types.Void() - elif isinstance(t_value, _schema_impl.Schema): - schema = t_value.cast_to(cls.schema_type) - else: - schema = _schema_impl.Schema.from_python_std(t_value, schema_type=cls.schema_type) - return cls(schema) - - @classmethod - def to_flyte_literal_type(cls): - """ - :rtype: flytekit.models.types.LiteralType - """ - return _idl_types.LiteralType(schema=cls.schema_type) - - @classmethod - def promote_from_model(cls, literal_model): - """ - Creates an object of this type from the model primitive defining it. 
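# SchemaInstantiator's __call__ override above lets the class object act as
# a factory: a zero-argument call short-circuits to create_at_any_location.
# The metaclass pattern in miniature:
class Instantiator(type):
    def __call__(cls, *args, **kwargs):
        if not args and not kwargs:
            return cls.create_default()  # the zero-arg shortcut
        return super().__call__(*args, **kwargs)

class Thing(metaclass=Instantiator):
    def __init__(self, value):
        self.value = value

    @classmethod
    def create_default(cls):
        return cls("generated")

assert Thing().value == "generated"  # shortcut path
assert Thing("x").value == "x"       # normal construction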
- :param flytekit.models.literals.Literal literal_model: - :rtype: Schema - """ - return cls(_schema_impl.Schema.promote_from_model(literal_model.scalar.schema)) - - @classmethod - def short_class_string(cls): - """ - :rtype: Text - """ - return repr(cls.schema_type) - - def __init__(self, value): - """ - :param flytekit.common.types.impl.schema.Schema value: Schema value to wrap - """ - super(Schema, self).__init__(scalar=_literals.Scalar(schema=value)) - - def to_python_std(self): - """ - :rtype: flytekit.common.types.impl.schema.Schema - """ - return self.scalar.schema - - def short_string(self): - """ - :rtype: Text - """ - return "{}".format( - self.scalar.schema, - ) - - -def schema_instantiator(columns=None): - """ - :param list[(Text, flytekit.common.types.base_sdk_types.FlyteSdkType)] columns: [Optional] Description of the - columns in the underlying schema. Should be tuples with the first element being the name. - :rtype: SchemaInstantiator - """ - if columns is not None and len(columns) == 0: - raise _user_exceptions.FlyteValueException( - columns, - "When specifying a Schema type with a known set of columns, a non-empty list must be provided as " "inputs", - ) - - class _Schema(Schema, metaclass=SchemaInstantiator): - _schema_type = _schema_impl.SchemaType(columns=columns) - - return _Schema - - -def schema_instantiator_from_proto(schema_type): - """ - :param flytekit.models.types.SchemaType schema_type: - :rtype: SchemaInstantiator - """ - - class _Schema(Schema, metaclass=SchemaInstantiator): - _schema_type = _schema_impl.SchemaType.promote_from_model(schema_type) - - return _Schema diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py deleted file mode 100644 index 3fc4f498fd..0000000000 --- a/flytekit/common/workflow.py +++ /dev/null @@ -1,309 +0,0 @@ -import datetime as _datetime -from typing import List - -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface -from flytekit.common import nodes as _nodes -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.launch_plan import SdkLaunchPlan -from flytekit.common.mixins import hash as _hash_mixin -from flytekit.common.mixins import registerable as _registerable -from flytekit.configuration import auth as _auth_config -from flytekit.configuration import internal as _internal_config -from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_models -from flytekit.models import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_models -from flytekit.models.admin import workflow as _admin_workflow_model -from flytekit.models.core import identifier as _identifier_model -from flytekit.models.core import workflow as _workflow_models - - -class SdkWorkflow( - _hash_mixin.HashOnReferenceMixin, - _registerable.HasDependencies, - _registerable.RegisterableEntity, - _workflow_models.WorkflowTemplate, - metaclass=_sdk_bases.ExtendedSdkType, -): - """ - Previously this class represented both local and control plane constructs. 
As of this writing, we are making this - class only a control plane class. Workflow constructs that rely on local code being present have been moved to - the new PythonWorkflow class. - """ - - def __init__( - self, - nodes, - interface, - output_bindings, - id, - metadata, - metadata_defaults, - ): - """ - :param list[flytekit.common.nodes.SdkNode] nodes: - :param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the - Workflow (inputs, outputs). This can include some optional parameters. - :param list[flytekit.models.literals.Binding] output_bindings: A list of output bindings that specify how to construct - workflow outputs. Bindings can pull node outputs or specify literals. All workflow outputs specified in - the interface field must be bound - in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes - to execute successfully in order to bind final outputs. - :param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is - globally unique across Flyte. - :param WorkflowMetadata metadata: This contains information on how to run the workflow. - :param flytekit.models.core.workflow.WorkflowMetadataDefaults metadata_defaults: Defaults to be passed - to nodes contained within workflow. - """ - for n in nodes: - for upstream in n.upstream_nodes: - if upstream.id is None: - raise _user_exceptions.FlyteAssertion( - "Some nodes contained in the workflow were not found in the workflow description. Please " - "ensure all nodes are either assigned to attributes within the class or an element in a " - "list, dict, or tuple which is stored as an attribute in the class." - ) - - super(SdkWorkflow, self).__init__( - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - interface=interface, - nodes=nodes, - outputs=output_bindings, - ) - self._sdk_nodes = nodes - self._has_registered = False - - @property - def upstream_entities(self): - return set(n.executable_sdk_object for n in self._sdk_nodes) - - @property - def interface(self): - """ - :rtype: flytekit.common.interface.TypedInterface - """ - return super(SdkWorkflow, self).interface - - @property - def entity_type_text(self): - """ - :rtype: Text - """ - return "Workflow" - - @property - def resource_type(self): - """ - Integer from _identifier.ResourceType enum - :rtype: int - """ - return _identifier_model.ResourceType.WORKFLOW - - def get_sub_workflows(self): - """ - Recursive call that returns all subworkflows in the current workflow - - :rtype: list[SdkWorkflow] - """ - result = [] - for node in self.nodes: - if node.workflow_node is not None and node.workflow_node.sub_workflow_ref is not None: - if node.executable_sdk_object is not None and node.executable_sdk_object.entity_type_text == "Workflow": - result.append(node.executable_sdk_object) - result.extend(node.executable_sdk_object.get_sub_workflows()) - else: - raise _system_exceptions.FlyteSystemException( - "workflow node with subworkflow found but bad executable " - "object {}".format(node.executable_sdk_object) - ) - - # get subworkflows in conditional branches - if node.branch_node is not None: - if_else: _workflow_models.IfElseBlock = node.branch_node.if_else - leaf_nodes: List[_nodes.SdkNode] = filter( - None, - [ - if_else.case.then_node, - *([] if if_else.other is None else [x.then_node for x in if_else.other]), - if_else.else_node, - ], - ) - for leaf_node in leaf_nodes: - exec_sdk_obj = leaf_node.executable_sdk_object - if 
exec_sdk_obj is not None and exec_sdk_obj.entity_type_text == "Workflow": - result.append(exec_sdk_obj) - result.extend(exec_sdk_obj.get_sub_workflows()) - - return result - - @classmethod - @_exception_scopes.system_entry_point - def fetch(cls, project, domain, name, version=None): - """ - This function uses the engine loader to call create a hydrated task from Admin. - :param Text project: - :param Text domain: - :param Text name: - :param Text version: - :rtype: SdkWorkflow - """ - version = version or _internal_config.VERSION.get() - workflow_id = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version) - admin_workflow = _flyte_engine.get_client().get_workflow(workflow_id) - cwc = admin_workflow.closure.compiled_workflow - primary_template = cwc.primary.template - sub_workflow_map = {sw.template.id: sw.template for sw in cwc.sub_workflows} - task_map = {t.template.id: t.template for t in cwc.tasks} - sdk_workflow = cls.promote_from_model(primary_template, sub_workflow_map, task_map) - sdk_workflow._id = workflow_id - sdk_workflow._has_registered = True - return sdk_workflow - - @classmethod - def get_non_system_nodes(cls, nodes): - """ - :param list[flytekit.models.core.workflow.Node] nodes: - :rtype: list[flytekit.models.core.workflow.Node] - """ - return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] - - @classmethod - def promote_from_model(cls, base_model, sub_workflows=None, tasks=None): - """ - :param flytekit.models.core.workflow.WorkflowTemplate base_model: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate] - sub_workflows: Provide a list of WorkflowTemplate - models (should be returned from Admin as part of the admin CompiledWorkflowClosure. Relevant sub-workflows - should always be provided. - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: Same as above - but for tasks. If tasks are not provided relevant TaskTemplates will be fetched from Admin - :rtype: SdkWorkflow - """ - base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) - sub_workflows = sub_workflows or {} - tasks = tasks or {} - node_map = { - n.id: _nodes.SdkNode.promote_from_model(n, sub_workflows, tasks) for n in base_model_non_system_nodes - } - - # Set upstream nodes for each node - for n in base_model_non_system_nodes: - current = node_map[n.id] - for upstream_id in current.upstream_node_ids: - upstream_node = node_map[upstream_id] - current << upstream_node - - # No inputs/outputs specified, see the constructor for more information on the overrides. 
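# promote_from_model above rebuilds the DAG in two passes: index nodes by
# id, then replay upstream edges with the overloaded "<<" operator
# (current << upstream_node). Stripped-down sketch of that idiom:
class Node:
    def __init__(self, node_id):
        self.id = node_id
        self.upstream_nodes = []

    def __lshift__(self, other):
        self.upstream_nodes.append(other)  # record an upstream dependency
        return self

a, b = Node("a"), Node("b")
b << a
assert a in b.upstream_nodes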
- return cls( - nodes=list(node_map.values()), - id=_identifier.Identifier.promote_from_model(base_model.id), - metadata=base_model.metadata, - metadata_defaults=base_model.metadata_defaults, - interface=_interface.TypedInterface.promote_from_model(base_model.interface), - output_bindings=base_model.outputs, - ) - - @_exception_scopes.system_entry_point - def register(self, project, domain, name, version): - """ - :param Text project: - :param Text domain: - :param Text name: - :param Text version: - """ - self.validate() - id_to_register = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version) - old_id = self.id - self._id = id_to_register - try: - client = _flyte_engine.get_client() - sub_workflows = self.get_sub_workflows() - client.create_workflow( - id_to_register, - _admin_workflow_model.WorkflowSpec( - self, - sub_workflows, - ), - ) - self._id = id_to_register - self._has_registered = True - return str(id_to_register) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - except Exception: - self._id = old_id - raise - - @_exception_scopes.system_entry_point - def serialize(self): - """ - Serializing a workflow should produce an object similar to what the registration step produces, in preparation - for actual registration to Admin. - - :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec - """ - sub_workflows = self.get_sub_workflows() - return _admin_workflow_model.WorkflowSpec( - self, - sub_workflows, - ).to_flyte_idl() - - @_exception_scopes.system_entry_point - def validate(self): - pass - - @_exception_scopes.system_entry_point - def create_launch_plan(self, *args, **kwargs): - # TODO: Correct after implementing new launch plan - assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() - kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() - - if not (assumable_iam_role or kubernetes_service_account): - raise _user_exceptions.FlyteValidationException("No assumable role or service account found") - auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account, - ) - - return SdkLaunchPlan( - workflow_id=self.id, - entity_metadata=_launch_plan_models.LaunchPlanMetadata( - schedule=_schedule_models.Schedule(""), - notifications=[], - ), - default_inputs=_interface_models.ParameterMap({}), - fixed_inputs=_literal_models.LiteralMap(literals={}), - labels=_common_models.Labels({}), - annotations=_common_models.Annotations({}), - auth_role=auth_role, - raw_output_data_config=_common_models.RawOutputDataConfig(""), - ) - - @_exception_scopes.system_entry_point - def __call__(self, *args, **input_map): - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only. 
We "
-                "detected {} positional args.".format(len(args))
-            )
-        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)
-
-        node = _nodes.SdkNode(
-            id=None,
-            metadata=_workflow_models.NodeMetadata(
-                "placeholder", _datetime.timedelta(), _literal_models.RetryStrategy(0)
-            ),
-            upstream_nodes=upstream_nodes,
-            bindings=sorted(bindings, key=lambda b: b.var),
-            sdk_workflow=self,
-        )
-        return node
diff --git a/flytekit/common/workflow_execution.py b/flytekit/common/workflow_execution.py
deleted file mode 100644
index 14695d0e68..0000000000
--- a/flytekit/common/workflow_execution.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import os as _os
-
-import six as _six
-from flyteidl.core import literals_pb2 as _literals_pb2
-
-from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions
-from flytekit.common import nodes as _nodes
-from flytekit.common import sdk_bases as _sdk_bases
-from flytekit.common import utils as _common_utils
-from flytekit.common.core import identifier as _core_identifier
-from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.mixins import artifact as _artifact
-from flytekit.common.types import helpers as _type_helpers
-from flytekit.engines.flyte import engine as _flyte_engine
-from flytekit.interfaces.data import data_proxy as _data_proxy
-from flytekit.models import execution as _execution_models
-from flytekit.models import literals as _literal_models
-from flytekit.models.core import execution as _core_execution_models
-
-
-class SdkWorkflowExecution(
-    _execution_models.Execution,
-    _artifact.ExecutionArtifact,
-    metaclass=_sdk_bases.ExtendedSdkType,
-):
-    def __init__(self, *args, **kwargs):
-        super(SdkWorkflowExecution, self).__init__(*args, **kwargs)
-        self._node_executions = None
-        self._inputs = None
-        self._outputs = None
-
-    @property
-    def node_executions(self):
-        """
-        :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution]
-        """
-        return self._node_executions or {}
-
-    @property
-    def inputs(self):
-        """
-        Returns the inputs to the execution in the standard Python format as dictated by the type engine.
-        :rtype: dict[Text, T]
-        """
-        if self._inputs is None:
-            client = _flyte_engine.get_client()
-            execution_data = client.get_execution_data(self.id)
-
-            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
-            if bool(execution_data.full_inputs.literals):
-                input_map = execution_data.full_inputs
-            elif execution_data.inputs.bytes > 0:
-                with _common_utils.AutoDeletingTempDir() as t:
-                    tmp_name = _os.path.join(t.name, "inputs.pb")
-                    _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
-                    input_map = _literal_models.LiteralMap.from_flyte_idl(
-                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
-                    )
-            else:
-                input_map = _literal_models.LiteralMap({})
-
-            self._inputs = _type_helpers.unpack_literal_map_to_sdk_python_std(input_map)
-        return self._inputs
-
-    @property
-    def outputs(self):
-        """
-        Returns the outputs of the execution in the standard Python format as dictated by the type engine. If the
-        execution ended in error or the execution is in progress, an exception will be raised.
-        :rtype: dict[Text, T] or None
-        """
-        if not self.is_complete:
-            raise _user_exceptions.FlyteAssertion(
-                "Please wait until the workflow execution has completed before requesting the outputs."
- ) - if self.error: - raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") - - if self._outputs is None: - client = _flyte_engine.get_client() - - execution_data = client.get_execution_data(self.id) - # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_outputs.literals): - output_map = execution_data.full_outputs - - elif execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - output_map = _literal_models.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - else: - output_map = _literal_models.LiteralMap({}) - - self._outputs = _type_helpers.unpack_literal_map_to_sdk_python_std(output_map) - return self._outputs - - @property - def error(self): - """ - If execution is in progress, raise an exception. Otherwise, return None if no error was present upon - reaching completion. - :rtype: flytekit.models.core.execution.ExecutionError or None - """ - if not self.is_complete: - raise _user_exceptions.FlyteAssertion( - "Please wait until a workflow has completed before checking for an " "error." - ) - return self.closure.error - - @property - def is_complete(self): - """ - Dictates whether or not the execution is complete. - :rtype: bool - """ - return self.closure.phase in { - _core_execution_models.WorkflowExecutionPhase.ABORTED, - _core_execution_models.WorkflowExecutionPhase.FAILED, - _core_execution_models.WorkflowExecutionPhase.SUCCEEDED, - _core_execution_models.WorkflowExecutionPhase.TIMED_OUT, - } - - @classmethod - def promote_from_model(cls, base_model): - """ - :param _execution_models.Execution base_model: - :rtype: SdkWorkflowExecution - """ - return cls( - closure=base_model.closure, - id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id), - spec=base_model.spec, - ) - - @classmethod - def fetch(cls, project, domain, name): - """ - :param Text project: - :param Text domain: - :param Text name: - :rtype: SdkWorkflowExecution - """ - wf_exec_id = _core_identifier.WorkflowExecutionIdentifier(project=project, domain=domain, name=name) - admin_exec = _flyte_engine.get_client().get_execution(wf_exec_id) - - return cls.promote_from_model(admin_exec) - - def sync(self): - """ - Syncs the state of the underlying execution artifact with the state observed by the platform. - :rtype: None - """ - if not self.is_complete or self._node_executions is None: - self._sync_closure() - self._node_executions = self.get_node_executions() - - def _sync_closure(self): - """ - Syncs the closure of the underlying execution artifact with the state observed by the platform. 
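# is_complete above is a membership test against the four terminal phases;
# the same test written against the phase constants this module imports:
from flytekit.models.core import execution as _core_execution_models

TERMINAL_PHASES = {
    _core_execution_models.WorkflowExecutionPhase.ABORTED,
    _core_execution_models.WorkflowExecutionPhase.FAILED,
    _core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
    _core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
}

def is_terminal(phase):
    return phase in TERMINAL_PHASES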
- :rtype: None - """ - if not self.is_complete: - client = _flyte_engine.get_client() - self._closure = client.get_execution(self.id).closure - - def get_node_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - client = _flyte_engine.get_client() - node_exec_models = {v.id.node_id: v for v in _iterate_node_executions(client, self.id, filters=filters)} - - return {k: _nodes.SdkNodeExecution.promote_from_model(v) for k, v in _six.iteritems(node_exec_models)} - - def terminate(self, cause): - """ - :param Text cause: - """ - _flyte_engine.get_client().terminate_execution(self.id, cause) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index f6fc2c633e..6dd9c28d35 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -1,12 +1,6 @@ import logging as _logging import os as _os - -import six as _six - -try: - import pathlib as _pathlib -except ImportError: - import pathlib2 as _pathlib # python 2 backport +import pathlib as _pathlib def set_flyte_config_file(config_file_path): @@ -38,7 +32,7 @@ def __init__(self, new_config_path, internal_overrides=None): import flytekit.configuration.common as _common self._internal_overrides = { - _common.format_section_key("internal", k): v for k, v in _six.iteritems(internal_overrides or {}) + _common.format_section_key("internal", k): v for k, v in (internal_overrides or {}).items() } self._new_config_path = new_config_path self._old_config_path = None @@ -47,13 +41,13 @@ def __init__(self, new_config_path, internal_overrides=None): def __enter__(self): import flytekit.configuration.internal as _internal - self._old_internals = {k: _os.environ.get(k) for k in _six.iterkeys(self._internal_overrides)} + self._old_internals = {k: _os.environ.get(k) for k in self._internal_overrides.keys()} self._old_config_path = _os.environ.get(_internal.CONFIGURATION_PATH.env_var) _os.environ.update(self._internal_overrides) set_flyte_config_file(self._new_config_path) def __exit__(self, exc_type, exc_val, exc_tb): - for k, v in _six.iteritems(self._old_internals): + for k, v in self._old_internals.items(): if v is not None: _os.environ[k] = v else: diff --git a/flytekit/configuration/common.py b/flytekit/configuration/common.py index 93f5c2d403..6e0b2088a4 100644 --- a/flytekit/configuration/common.py +++ b/flytekit/configuration/common.py @@ -2,7 +2,7 @@ import configparser as _configparser import os as _os -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import user as _user_exceptions def format_section_key(section, key): diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index d420e7a019..5c4061fa4f 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -1,5 +1,5 @@ -from flytekit.common import constants as _constants from flytekit.configuration import common as _config_common +from flytekit.core import constants as _constants URL = _config_common.FlyteStringConfigurationEntry("platform", "url") diff --git a/flytekit/contrib/__init__.py b/flytekit/contrib/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/contrib/sensors/__init__.py b/flytekit/contrib/sensors/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/contrib/sensors/base_sensor.py b/flytekit/contrib/sensors/base_sensor.py deleted file mode 100644 
index 8522918a34..0000000000
--- a/flytekit/contrib/sensors/base_sensor.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import abc as _abc
-import datetime as _datetime
-import logging as _logging
-import sys as _sys
-import time as _time
-import traceback as _traceback
-
-import six as _six
-
-
-class Sensor(object, metaclass=_abc.ABCMeta):
-    def __init__(self, evaluation_interval=None, max_failures=0):
-        """
-        :param datetime.timedelta evaluation_interval: This is the time to wait between evaluation attempts of this
-            sensor. If the sensor takes longer to evaluate than the evaluation_interval, it will immediately begin
-            evaluation again.
-        :param int max_failures: This is the maximum number of failures that can happen while attempting to sense
-            before perma-failing.
-        """
-        if evaluation_interval is None:
-            evaluation_interval = _datetime.timedelta(seconds=30)
-        self._evaluation_interval = evaluation_interval
-        self._max_failures = max_failures
-        self._failures = 0
-        self._exc_info = None
-        self._last_executed_time = _datetime.datetime(year=1990, month=6, day=30)  # Arbitrary date in the past.
-        self._sensed = False
-
-    @_abc.abstractmethod
-    def _do_poll(self):
-        """
-        :rtype: (bool, Optional[datetime.timedelta])
-        """
-        pass
-
-    def sense_with_wait_hint(self):
-        """
-        Attempts to sense based on the poll function. The method will return the last sensed result. If the
-        rate of sensing is exceeded for the sensor, the timedelta in the returned tuple will tell the caller how long it
-        should sleep before trying again.
-        :rtype: (bool, Optional[datetime.timedelta])
-        """
-        # Return cached success. This simplifies code for the conditional sensors.
-        if self._sensed:
-            return self._sensed, self._evaluation_interval
-
-        # Perma-fail to prevent abuse of sensed objects. 
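# The rate limit above: when the previous evaluation is more recent than
# evaluation_interval, the caller gets a remaining-wait hint instead of a
# fresh poll. That computation in isolation:
import datetime

def remaining_wait(last_executed_time, evaluation_interval, now=None):
    now = now or datetime.datetime.utcnow()
    remaining = evaluation_interval - (now - last_executed_time)
    return remaining if remaining > datetime.timedelta() else None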
- if self._failures > self._max_failures: - _six.reraise(*self._exc_info) - - now = _datetime.datetime.utcnow() - - time_to_wait_eval_period = self._evaluation_interval - (now - self._last_executed_time) - if time_to_wait_eval_period > _datetime.timedelta(): - return self._sensed, time_to_wait_eval_period - - try: - self._sensed, time_to_wait = self._do_poll() - time_to_wait = time_to_wait or self._evaluation_interval - except BaseException: - self._failures += 1 - self._exc_info = _sys.exc_info() - if self._failures > self._max_failures: - _logging.error( - "{} failed (with no remaining retries) due to:\n\n{}".format(self, _traceback.format_exc()), - ) - raise - else: - _logging.warn("{} failed (but will retry) due to:\n\n{}".format(self, _traceback.format_exc())) - time_to_wait = self._evaluation_interval - - self._last_executed_time = _datetime.datetime.utcnow() - - return self._sensed, time_to_wait - - def sense(self, timeout=None): - """ - Attempts - :param datetime.timedelta timeout: - :rtype: bool - """ - started = _datetime.datetime.utcnow() - while True: - sensed, time_to_wait = self.sense_with_wait_hint() - if sensed: - return True - if time_to_wait: - _time.sleep(time_to_wait.total_seconds()) - if timeout is not None and (_datetime.datetime.utcnow() - started) > timeout: - return False diff --git a/flytekit/contrib/sensors/impl.py b/flytekit/contrib/sensors/impl.py deleted file mode 100644 index 8276320922..0000000000 --- a/flytekit/contrib/sensors/impl.py +++ /dev/null @@ -1,107 +0,0 @@ -from flytekit.contrib.sensors.base_sensor import Sensor as _Sensor -from flytekit.plugins import hmsclient as _hmsclient - - -class _HiveSensor(_Sensor): - def __init__(self, host, port, schema="default", **kwargs): - """ - :param Text host: - :param Text port: - :param Text schema: The schema/database that we should consider. - :param **kwargs: See flytekit.contrib.sensors.base_sensor.Sensor for more - parameters. - """ - self._schema = schema - self._host = host - self._port = port - self._hive_metastore_client = _hmsclient.HMSClient(host=host, port=port) - super(_HiveSensor, self).__init__(**kwargs) - - -class HiveTableSensor(_HiveSensor): - def __init__(self, table_name, host, port, **kwargs): - """ - :param Text host: The host for the Hive metastore Thrift service. - :param Text port: The port for the Hive metastore Thrift Service - :param Text table_name: The name of the table to look for. - :param **kwargs: See _HiveSensor and flytekit.contrib.sensors.base_sensor.Sensor for more - parameters. - """ - super(HiveTableSensor, self).__init__(host, port, **kwargs) - self._table_name = table_name - - def _do_poll(self): - """ - :rtype: (bool, Optional[datetime.timedelta]) - """ - with self._hive_metastore_client as client: - try: - client.get_table(self._schema, self._table_name) - return True, None - except _hmsclient.genthrift.hive_metastore.ttypes.NoSuchObjectException: - return False, None - - -class HiveNamedPartitionSensor(_HiveSensor): - def __init__(self, table_name, partition_names, host, port, **kwargs): - """ - This class allows sensing for a specific named Hive Partition. This is the preferred partition sensing - operator because it is more efficient than evaluating a filter expression. - - :param Text table_name: The name of the table - :param Text partition_name: The name of the partition to listen for (example: 'ds=2017-01-01/region=NYC') - :param Text host: The host for the Hive metastore Thrift service. 
- :param Text port: The port for the Hive metastore Thrift Service - :param **kwargs: See _HiveSensor and flytekit.contrib.sensors.base_sensor.Sensor for more - parameters. - """ - super(HiveNamedPartitionSensor, self).__init__(host, port, **kwargs) - self._table_name = table_name - self._partition_names = partition_names - - def _do_poll(self): - """ - :rtype: (bool, Optional[datetime.timedelta]) - """ - with self._hive_metastore_client as client: - try: - for partition_name in self._partition_names: - client.get_partition_by_name(self._schema, self._table_name, partition_name) - return True, None - except _hmsclient.genthrift.hive_metastore.ttypes.NoSuchObjectException: - return False, None - - -class HiveFilteredPartitionSensor(_HiveSensor): - def __init__(self, table_name, partition_filter, host, port, **kwargs): - """ - This class allows sensing for any Hive partition that matches a filter expression. It is recommended that the - user should use HiveNamedPartitionSensor instead when possible because it is a more efficient API. - - :param Text table_name: The name of the table - :param Text partition_filter: A filter expression for the partition. (example: "ds = '2017-01-01' and - region='NYC') - :param Text host: The host for the Hive metastore Thrift service. - :param Text port: The port for the Hive metastore Thrift Service - :param **kwargs: See _HiveSensor and flytekit.contrib.sensors.base_sensor.Sensor for more - parameters. - """ - super(HiveFilteredPartitionSensor, self).__init__(host, port, **kwargs) - self._table_name = table_name - self._partition_filter = partition_filter - - def _do_poll(self): - """ - :rtype: (bool, Optional[datetime.timedelta]) - """ - with self._hive_metastore_client as client: - partitions = client.get_partitions_by_filter( - db_name=self._schema, - tbl_name=self._table_name, - filter=self._partition_filter, - max_parts=1, - ) - if partitions: - return True, None - else: - return False, None diff --git a/flytekit/contrib/sensors/task.py b/flytekit/contrib/sensors/task.py deleted file mode 100644 index 0749fc39dc..0000000000 --- a/flytekit/contrib/sensors/task.py +++ /dev/null @@ -1,127 +0,0 @@ -from flytekit.common import constants as _common_constants -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.contrib.sensors.base_sensor import Sensor as _Sensor - - -class SensorTask(_sdk_runnable.SdkRunnableTask): - def _execute_user_code(self, context, inputs): - sensor = super(SensorTask, self)._execute_user_code(context=context, inputs=inputs) - if sensor is not None: - if not isinstance(sensor, _Sensor): - raise _user_exceptions.FlyteTypeException( - received_type=type(sensor), - expected_type=_Sensor, - ) - succeeded = sensor.sense() - if not succeeded: - raise _user_exceptions.FlyteRecoverableException() - - -def sensor_task( - _task_function=None, - retries=0, - interruptible=None, - deprecated="", - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - timeout=None, - environment=None, - cls=None, -): - """ - Decorator to create a Sensor Task definition. This task will run as a single unit of work on the platform. - - .. 
code-block:: python - @sensor_task(retries=3) - def my_task(wf_params): - return HiveTableSensor( - schema='default', - table_name='mocked_table', - host='localhost', - port=1234, - ) - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs. No keyword - arguments are allowed for wrapped task functions. - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - .. note:: - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - :param bool interruptible: Specify whether task is interruptible - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for the task to run. Default is set by platform-level configuration. - .. note:: - This is currently not supported by the platform. - :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs. - Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for the task to execute. Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request. - .. note:: - This is currently not supported by the platform. - :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not - guaranteed! If not specified, it is set equal to gpu_request. - :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to - memory_request. - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - :param cls: This can be used to override the task implementation with a user-defined extension. 
The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to - inject bespoke logic into the base Flyte programming model. Ideally, should be a sub-class of SensorTask or - otherwise mimic the behavior. - :rtype: SensorTask - """ - - def wrapper(fn): - return (SensorTask or cls)( - task_function=fn, - task_type=_common_constants.SdkTaskType.SENSOR_TASK, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - timeout=timeout, - environment=environment, - custom={}, - discovery_version="", - discoverable=False, - cache_serializable=False, - ) - - # This is syntactic-sugar, so that when calling this decorator without args, you can either - # do it with () or without any () - if _task_function: - return wrapper(_task_function) - else: - return wrapper diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 53fa011185..f590ba2033 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -23,8 +23,13 @@ from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union -from flytekit.common.tasks.sdk_runnable import ExecutionParameters -from flytekit.core.context_manager import FlyteContext, FlyteContextManager, FlyteEntities, SerializationSettings +from flytekit.core.context_manager import ( + ExecutionParameters, + FlyteContext, + FlyteContextManager, + FlyteEntities, + SerializationSettings, +) from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( diff --git a/flytekit/common/constants.py b/flytekit/core/constants.py similarity index 100% rename from flytekit/common/constants.py rename to flytekit/core/constants.py diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 75c6449c07..d46057c623 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,11 +1,11 @@ from enum import Enum from typing import Any, Dict, List, Optional, Type -from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.context_manager import SerializationSettings from flytekit.core.interface import Interface from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.utils import _get_container_definition from flytekit.models import task as _task_model diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 627cf3ea11..1db5e11d5e 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -23,19 +23,21 @@ import typing from contextlib import contextmanager from dataclasses import dataclass, field +from datetime import datetime from enum import Enum from typing import Any, Dict, Generator, List, Optional, Union from docker_image import reference from flytekit.clients import friendly as friendly_client # noqa -from flytekit.common.core.identifier import WorkflowExecutionIdentifier as _SdkWorkflowExecutionIdentifier -from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.configuration import images, internal from flytekit.configuration import sdk as _sdk_config 
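The import churn in these files follows a single pattern: symbols leave ``flytekit.common`` for ``flytekit.core`` and ``flytekit.exceptions``. An illustrative before/after for a downstream module (the commented-out lines are the pre-refactor spellings):

.. code-block:: python

    # Before this refactor:
    #   from flytekit.common import constants as _constants
    #   from flytekit.common.exceptions import user as _user_exceptions
    #   from flytekit.common.tasks.sdk_runnable import ExecutionParameters

    # After this refactor:
    from flytekit.core import constants as _constants
    from flytekit.exceptions import user as _user_exceptions
    from flytekit.core.context_manager import ExecutionParameters
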
+from flytekit.configuration import secrets
+from flytekit.core import mock_stats, utils
 from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider
 from flytekit.core.node import Node
-from flytekit.engines.unit import mock_stats as _mock_stats
+from flytekit.interfaces.cli_identifiers import WorkflowExecutionIdentifier
+from flytekit.interfaces.stats import taggable
 from flytekit.models.core import identifier as _identifier

 # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin
@@ -132,6 +134,216 @@ def get_image_config(img_name: Optional[str] = None) -> ImageConfig:
     return ImageConfig(default_image=default_img, images=other_images)
+
+class ExecutionParameters(object):
+    """
+    This is a run-time user-centric context object that is accessible to every @task method. It can be accessed using
+
+    .. code-block:: python
+
+        flytekit.current_context()
+
+    This object provides the following
+    * a statsd handler
+    * a logging handler
+    * the execution ID as an :py:class:`flytekit.models.core.identifier.WorkflowExecutionIdentifier` object
+    * a working directory for the user to write arbitrary files to
+
+    Please do not confuse this object with the :py:class:`flytekit.FlyteContext` object.
+    """
+
+    @dataclass(init=False)
+    class Builder(object):
+        stats: taggable.TaggableStats
+        execution_date: datetime
+        logging: _logging
+        execution_id: str
+        attrs: typing.Dict[str, typing.Any]
+        working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir]
+
+        def __init__(self, current: typing.Optional[ExecutionParameters] = None):
+            self.stats = current.stats if current else None
+            self.execution_date = current.execution_date if current else None
+            self.working_dir = current.working_directory if current else None
+            self.execution_id = current.execution_id if current else None
+            self.logging = current.logging if current else None
+            self.attrs = current._attrs if current else {}
+
+        def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:
+            self.attrs[key] = v
+            return self
+
+        def build(self) -> ExecutionParameters:
+            if not isinstance(self.working_dir, utils.AutoDeletingTempDir):
+                pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True)
+            return ExecutionParameters(
+                execution_date=self.execution_date,
+                stats=self.stats,
+                tmp_dir=self.working_dir,
+                execution_id=self.execution_id,
+                logging=self.logging,
+                **self.attrs,
+            )
+
+    @staticmethod
+    def new_builder(current: ExecutionParameters = None) -> Builder:
+        return ExecutionParameters.Builder(current=current)
+
+    def builder(self) -> Builder:
+        return ExecutionParameters.Builder(current=self)
+
+    def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, **kwargs):
+        """
+        Args:
+            execution_date: Date when the execution is running
+            tmp_dir: temporary directory for the execution
+            stats: handle to emit stats
+            execution_id: Identifier for the execution
+            logging: handle to logging
+        """
+        self._stats = stats
+        self._execution_date = execution_date
+        self._working_directory = tmp_dir
+        self._execution_id = execution_id
+        self._logging = logging
+        # AutoDeletingTempDir's should be used with a with block, which creates upon entry
+        self._attrs = kwargs
+        # It is safe to recreate the Secrets Manager
+        self._secrets_manager = SecretsManager()
+
+    @property
+    def stats(self) -> taggable.TaggableStats:
+        """
+        A handle to a special statsd object that provides usefully tagged stats.
+        TODO: Usage examples and better comments
+        """
+        return self._stats
+
+    @property
+    def logging(self) -> _logging:
+        """
+        A handle to a useful logging object.
+        TODO: Usage examples
+        """
+        return self._logging
+
+    @property
+    def working_directory(self) -> utils.AutoDeletingTempDir:
+        """
+        A handle to a special working directory for easily producing temporary files.
+
+        TODO: Usage examples
+        TODO: This does not always return an AutoDeletingTempDir
+        """
+        return self._working_directory
+
+    @property
+    def execution_date(self) -> datetime:
+        """
+        This is a datetime representing the time at which a workflow was started. This is consistent across all tasks
+        executed in a workflow or sub-workflow.
+
+        .. note::
+
+            Do NOT use this execution_date to drive any production logic. It might be useful as a tag for data to help
+            in debugging.
+        """
+        return self._execution_date
+
+    @property
+    def execution_id(self) -> str:
+        """
+        This is the identifier of the workflow execution within the underlying engine. It will be consistent across all
+        task executions in a workflow or sub-workflow execution.
+
+        .. note::
+
+            Do NOT use this execution_id to drive any production logic. This execution ID should only be used as a tag
+            on output data to link back to the workflow run that created it.
+        """
+        return self._execution_id
+
+    @property
+    def secrets(self) -> SecretsManager:
+        return self._secrets_manager
+
+    def __getattr__(self, attr_name: str) -> typing.Any:
+        """
+        This houses certain task-specific context. For example, in Spark it houses the SparkSession, etc.
+        """
+        attr_name = attr_name.upper()
+        if self._attrs and attr_name in self._attrs:
+            return self._attrs[attr_name]
+        raise AssertionError(f"{attr_name} not available as a parameter in Flyte context - are you in the right task type?")
+
+    def has_attr(self, attr_name: str) -> bool:
+        attr_name = attr_name.upper()
+        if self._attrs and attr_name in self._attrs:
+            return True
+        return False
+
+    def get(self, key: str) -> typing.Any:
+        """
+        Returns task-specific context if present, else raises an error. The returned context will match the key.
+        """
+        return self.__getattr__(attr_name=key)
+
+
+class SecretsManager(object):
+    """
+    This provides secrets resolution logic at runtime.
+    The resolution order is:
+    - Try the env var first. The env var name is built from configuration.SECRETS_ENV_PREFIX and is all upper
+      cased.
+    - If not found, try the file whose name matches, lower cased:
+      ``configuration.SECRETS_DEFAULT_DIR/<group>/configuration.SECRETS_FILE_PREFIX<key>``
+
+    All configuration values can always be overridden by injecting an environment variable.
+    """
+
+    def __init__(self):
+        self._base_dir = str(secrets.SECRETS_DEFAULT_DIR.get()).strip()
+        self._file_prefix = str(secrets.SECRETS_FILE_PREFIX.get()).strip()
+        self._env_prefix = str(secrets.SECRETS_ENV_PREFIX.get()).strip()
+
+    def get(self, group: str, key: str) -> str:
+        """
+        Retrieves a secret using the resolution order: env var first, then file. Raises a ValueError if neither
+        is found.
+        """
+        self.check_group_key(group, key)
+        env_var = self.get_secrets_env_var(group, key)
+        fpath = self.get_secrets_file(group, key)
+        v = os.environ.get(env_var)
+        if v is not None:
+            return v
+        if os.path.exists(fpath):
+            with open(fpath, "r") as f:
+                return f.read().strip()
+        raise ValueError(
+            f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}"
+        )
+
+    def get_secrets_env_var(self, group: str, key: str) -> str:
+        """
+        Returns the name of the env var to look for the secret in.
+        """
+        self.check_group_key(group, key)
+        return f"{self._env_prefix}{group.upper()}_{key.upper()}"
+
+    def get_secrets_file(self, group: str, key: str) -> str:
+        """
+        Returns the path of the file to look for the secret in.
+        """
+        self.check_group_key(group, key)
+        return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}")
+
+    @staticmethod
+    def check_group_key(group: str, key: str):
+        if group is None or group == "":
+            raise ValueError("secrets group is a mandatory field.")
+        if key is None or key == "":
+            raise ValueError("secrets key is a mandatory field.")
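To make the resolution order concrete, a small usage sketch (the group and key names are made up; the actual prefix and directory defaults come from ``flytekit.configuration.secrets``):

.. code-block:: python

    import os

    from flytekit.core.context_manager import SecretsManager

    sm = SecretsManager()
    # 1) An env var named <env_prefix><GROUP>_<KEY> wins if it is set.
    os.environ[sm.get_secrets_env_var("user", "token")] = "s3cr3t"
    assert sm.get("user", "token") == "s3cr3t"
    # 2) Otherwise the file <base_dir>/<group>/<file_prefix><key> is read;
    #    if neither exists, get() raises a ValueError.
    print(sm.get_secrets_file("user", "token"))
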
+
+
 @dataclass
 class EntrypointSettings(object):
     """
@@ -684,9 +896,9 @@ def initialize():
     # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
     # are already acquainted with
     default_user_space_params = ExecutionParameters(
-        execution_id=str(_SdkWorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
+        execution_id=str(WorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
         execution_date=_datetime.datetime.utcnow(),
-        stats=_mock_stats.MockStats(),
+        stats=mock_stats.MockStats(),
         logging=_logging,
         tmp_dir=user_space_path,
     )
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index b121d053fa..00c233dd8b 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -32,8 +32,8 @@
 from typing import Dict, Union
 from uuid import UUID

-from flytekit.common.exceptions.user import FlyteAssertion
-from flytekit.common.utils import PerformanceTimer
+from flytekit.core.utils import PerformanceTimer
+from flytekit.exceptions.user import FlyteAssertion
 from flytekit.interfaces.random import random
 from flytekit.loggers import logger
diff --git a/flytekit/common/mixins/hash.py b/flytekit/core/hash.py
similarity index 100%
rename from flytekit/common/mixins/hash.py
rename to flytekit/core/hash.py
diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py
index 263ac05fb7..b158a9434c 100644
--- a/flytekit/core/interface.py
+++ b/flytekit/core/interface.py
@@ -8,10 +8,10 @@
 from collections import OrderedDict
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union

-from flytekit.common.exceptions.user import FlyteValidationException
 from flytekit.core import context_manager
 from flytekit.core.docstring import Docstring
 from flytekit.core.type_engine import TypeEngine
+from flytekit.exceptions.user import FlyteValidationException
 from flytekit.loggers import logger
 from flytekit.models import interface as _interface_models
 from flytekit.types.pickle import FlytePickle
diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py
index 42645a4797..29838cffcd 100644
--- a/flytekit/core/map_task.py
+++ b/flytekit/core/map_task.py
@@ -8,12 +8,12 @@
 from itertools import count
 from typing import Any, Dict, List, Optional, Type

-from
flytekit.common.constants import SdkTaskType -from flytekit.common.exceptions import scopes as exception_scopes from flytekit.core.base_task import PythonTask +from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.exceptions import scopes as exception_scopes from flytekit.models.array_job import ArrayJob from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql diff --git a/flytekit/engines/unit/mock_stats.py b/flytekit/core/mock_stats.py similarity index 100% rename from flytekit/engines/unit/mock_stats.py rename to flytekit/core/mock_stats.py diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 0567e7b64a..cf2625b87a 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,8 +4,8 @@ import typing from typing import Any, List -from flytekit.common.utils import _dnsify from flytekit.core.resources import Resources +from flytekit.core.utils import _dnsify from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 120e9995ba..6eb4f51d81 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -3,13 +3,13 @@ import collections from typing import Type, Union -from flytekit.common.exceptions import user as _user_exceptions from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import VoidPromise from flytekit.core.workflow import WorkflowBase +from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import logger # This file exists instead of moving to node.py because it needs Task/Workflow/LaunchPlan and those depend on Node diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 7c9b780e19..17a4837432 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -7,8 +7,7 @@ from typing_extensions import Protocol -from flytekit.common import constants as _common_constants -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context from flytekit.core import interface as flyte_interface from flytekit.core import type_engine @@ -16,6 +15,7 @@ from flytekit.core.interface import Interface from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine +from flytekit.exceptions import user as _user_exceptions from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models import literals as _literals_models diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 689f5781dc..0226760f08 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -4,12 +4,12 @@ import re from typing import Callable, Dict, List, Optional, TypeVar -from flytekit.common.tasks.raw_container import 
_get_container_definition from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance +from flytekit.core.utils import _get_container_definition from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index eaeb509d2e..c5a716c3cb 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -5,13 +5,12 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 -from flytekit.common import utils as common_utils -from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin from flytekit.core.context_manager import FlyteContext, Image, ImageConfig, SerializationSettings from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.tracker import TrackedInstance +from flytekit.core.utils import _get_container_definition, load_proto_from_file from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.core import identifier as identifier_models @@ -232,7 +231,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore ctx.file_access.get_data(loader_args[0], task_template_local_path) - task_template_proto = common_utils.load_proto_from_file(_tasks_pb2.TaskTemplate, task_template_local_path) + task_template_proto = load_proto_from_file(_tasks_pb2.TaskTemplate, task_template_local_path) task_template_model = _task_model.TaskTemplate.from_flyte_idl(task_template_proto) executor_class = load_object_from_module(loader_args[1]) diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 25a363b070..fa98b7ca89 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -19,7 +19,6 @@ from enum import Enum from typing import Any, Callable, List, Optional, TypeVar, Union -from flytekit.common.exceptions import scopes as exception_scopes from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring @@ -32,6 +31,7 @@ WorkflowMetadata, WorkflowMetadataDefaults, ) +from flytekit.exceptions import scopes as exception_scopes from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models @@ -178,7 +178,7 @@ def compile_into_workflow( with FlyteContextManager.with_context(ctx.with_compilation_state(cs)): # TODO: Resolve circular import - from flytekit.common.translator import get_serializable + from flytekit.tools.translator import get_serializable workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY) defaults = WorkflowMetadataDefaults( diff --git 
a/flytekit/core/reference.py b/flytekit/core/reference.py index ed16ba6353..6a88549c43 100644 --- a/flytekit/core/reference.py +++ b/flytekit/core/reference.py @@ -2,10 +2,10 @@ from typing import Dict, Type -from flytekit.common.exceptions.user import FlyteValidationException from flytekit.core.launch_plan import ReferenceLaunchPlan from flytekit.core.task import ReferenceTask from flytekit.core.workflow import ReferenceWorkflow +from flytekit.exceptions.user import FlyteValidationException from flytekit.models.core import identifier as _identifier_model diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 22090838ff..42be1313cc 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Type, Union -from flytekit.common.exceptions import user as _user_exceptions from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.promise import ( @@ -14,6 +13,7 @@ translate_inputs_to_literals, ) from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index 637cfe4b47..ad6d39ef38 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -2,9 +2,10 @@ from typing import Any, Generic, Type, TypeVar, Union -from flytekit import ExecutionParameters, FlyteContext, FlyteContextManager, logger +from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine +from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index cd2e8be02b..56f145b4b6 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -4,7 +4,7 @@ import logging as _logging from typing import Callable -from flytekit.common.exceptions import system as _system_exceptions +from flytekit.exceptions import system as _system_exceptions class InstanceTrackingMeta(type): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9cca58b4bc..07faf178ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -25,10 +25,9 @@ from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema -from flytekit.common.exceptions import user as user_exceptions -from flytekit.common.types import primitives as _primitives from flytekit.core.context_manager import FlyteContext from flytekit.core.type_helpers import load_type_from_tag +from flytekit.exceptions import user as user_exceptions from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models @@ -257,7 +256,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"evaluation doesn't work with json dataclasses" ) - return _primitives.Generic.to_flyte_literal_type(metadata=schema) + return 
_type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if not dataclasses.is_dataclass(python_val): @@ -788,7 +787,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") - return _primitives.Generic.to_flyte_literal_type() + return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT) def to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[dict], expected: LiteralType @@ -1001,7 +1000,7 @@ def _register_default_type_transformers(): SimpleTransformer( "int", int, - _primitives.Integer.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER), lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))), lambda x: x.scalar.primitive.integer, ) @@ -1011,7 +1010,7 @@ def _register_default_type_transformers(): SimpleTransformer( "float", float, - _primitives.Float.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.FLOAT), lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))), _check_and_covert_float, ) @@ -1021,7 +1020,7 @@ def _register_default_type_transformers(): SimpleTransformer( "bool", bool, - _primitives.Boolean.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.BOOLEAN), lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))), lambda x: x.scalar.primitive.boolean, ) @@ -1031,7 +1030,7 @@ def _register_default_type_transformers(): SimpleTransformer( "str", str, - _primitives.String.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.STRING), lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))), lambda x: x.scalar.primitive.string_value, ) @@ -1041,7 +1040,7 @@ def _register_default_type_transformers(): SimpleTransformer( "datetime", _datetime.datetime, - _primitives.Datetime.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.DATETIME), lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))), lambda x: x.scalar.primitive.datetime, ) @@ -1051,7 +1050,7 @@ def _register_default_type_transformers(): SimpleTransformer( "timedelta", _datetime.timedelta, - _primitives.Timedelta.to_flyte_literal_type(), + _type_models.LiteralType(simple=_type_models.SimpleType.DURATION), lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))), lambda x: x.scalar.primitive.duration, ) diff --git a/flytekit/common/utils.py b/flytekit/core/utils.py similarity index 57% rename from flytekit/common/utils.py rename to flytekit/core/utils.py index 1d0395fd90..71dd45f581 100644 --- a/flytekit/common/utils.py +++ b/flytekit/core/utils.py @@ -5,10 +5,10 @@ import time as _time from hashlib import sha224 as _sha224 from pathlib import Path +from typing import Dict, List -import flytekit as _flytekit -from flytekit.configuration import sdk as _sdk_config -from flytekit.models.core import identifier as _identifier +from flytekit.configuration import resources as _resource_config +from flytekit.models import task as _task_models def _dnsify(value: str) -> str: @@ -49,6 +49,84 @@ def _dnsify(value: str) -> str: return res +def _get_container_definition( + image: str, + command: List[str], + args: List[str], + data_loading_config: _task_models.DataLoadingConfig, + 
storage_request: str = None, + ephemeral_storage_request: str = None, + cpu_request: str = None, + gpu_request: str = None, + memory_request: str = None, + storage_limit: str = None, + ephemeral_storage_limit: str = None, + cpu_limit: str = None, + gpu_limit: str = None, + memory_limit: str = None, + environment: Dict[str, str] = None, +) -> _task_models.Container: + storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() + storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() + ephemeral_storage_limit = ephemeral_storage_limit or _resource_config.DEFAULT_EPHEMERAL_STORAGE_LIMIT.get() + ephemeral_storage_request = ephemeral_storage_request or _resource_config.DEFAULT_EPHEMERAL_STORAGE_REQUEST.get() + cpu_limit = cpu_limit or _resource_config.DEFAULT_CPU_LIMIT.get() + cpu_request = cpu_request or _resource_config.DEFAULT_CPU_REQUEST.get() + gpu_limit = gpu_limit or _resource_config.DEFAULT_GPU_LIMIT.get() + gpu_request = gpu_request or _resource_config.DEFAULT_GPU_REQUEST.get() + memory_limit = memory_limit or _resource_config.DEFAULT_MEMORY_LIMIT.get() + memory_request = memory_request or _resource_config.DEFAULT_MEMORY_REQUEST.get() + + requests = [] + if storage_request: + requests.append( + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + ) + if ephemeral_storage_request: + requests.append( + _task_models.Resources.ResourceEntry( + _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + ) + ) + if cpu_request: + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + if gpu_request: + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + if memory_request: + requests.append( + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) + ) + + limits = [] + if storage_limit: + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + if ephemeral_storage_limit: + limits.append( + _task_models.Resources.ResourceEntry( + _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + ) + ) + if cpu_limit: + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + if gpu_limit: + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + if memory_limit: + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) + + if environment is None: + environment = {} + + return _task_models.Container( + image=image, + command=command, + args=args, + resources=_task_models.Resources(limits=limits, requests=requests), + env=environment, + config={}, + data_loading_config=data_loading_config, + ) + + def load_proto_from_file(pb2_type, path): with open(path, "rb") as reader: out = pb2_type() @@ -62,10 +140,6 @@ def write_proto_to_file(proto, path): writer.write(proto.SerializeToString()) -def get_version_message(): - return "Welcome to Flyte! 
Version: {}".format(_flytekit.__version__) - - class Directory(object): def __init__(self, path): """ @@ -160,65 +234,3 @@ def __exit__(self, exc_type, exc_val, exc_tb): end_process_time - self._start_process_time, ) ) - - -class ExitStack(object): - def __init__(self, entered_stack=None): - self._contexts = entered_stack - - def enter_context(self, context): - out = context.__enter__() - self._contexts.append(context) - return out - - def pop_all(self): - entered_stack = self._contexts - self._contexts = None - return ExitStack(entered_stack=entered_stack) - - def __enter__(self): - if self._contexts is not None: - raise Exception("A non-empty context stack cannot be entered.") - self._contexts = [] - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - first_exception = None - if self._contexts is not None: - while len(self._contexts) > 0: - try: - self._contexts.pop().__exit__(exc_type, exc_val, exc_tb) - except Exception as ex: - # Catch all to try to clean up all exits before re-raising the first exception - if first_exception is None: - first_exception = ex - if first_exception is not None: - raise first_exception - return False - - -def fqdn(module, name, entity_type=None): - """ - :param Text module: - :param Text name: - :param int entity_type: _identifier.ResourceType enum - :rtype: Text - """ - fmt = _sdk_config.NAME_FORMAT.get() - if entity_type == _identifier.ResourceType.WORKFLOW: - fmt = _sdk_config.WORKFLOW_NAME_FORMAT.get() or fmt - elif entity_type == _identifier.ResourceType.TASK: - fmt = _sdk_config.TASK_NAME_FORMAT.get() or fmt - elif entity_type == _identifier.ResourceType.LAUNCH_PLAN: - fmt = _sdk_config.LAUNCH_PLAN_NAME_FORMAT.get() or fmt - return fmt.format(module=module, name=name) - - -def fqdn_safe(module, key, entity_type=None): - """ - :param Text module: - :param Text key: - :param int entity_type: _identifier.ResourceType enum - :rtype: Text - """ - return _dnsify(fqdn(module, key, entity_type=entity_type)) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index ffa6aae934..77d7a6a936 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -5,9 +5,7 @@ from functools import update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from flytekit.common import constants as _common_constants -from flytekit.common.exceptions import scopes as exception_scopes -from flytekit.common.exceptions.user import FlyteValidationException, FlyteValueException +from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.condition import ConditionalSection @@ -34,6 +32,8 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import scopes as exception_scopes +from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models diff --git a/flytekit/engines/__init__.py b/flytekit/engines/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py deleted file mode 100644 index ed1dfd9d37..0000000000 --- a/flytekit/engines/common.py +++ /dev/null @@ -1,422 +0,0 
@@ -import abc as _abc - -from flytekit.models import common as _common_models - - -class BaseWorkflowExecutor(object, metaclass=_common_models.FlyteABCMeta): - """ - This class must be implemented for any engine to create, interact with, and execute workflows using the - FlyteKit SDK. - """ - - def __init__(self, sdk_workflow): - """ - :param flytekit.common.workflow.SdkWorkflow sdk_workflow: - """ - self._sdk_workflow = sdk_workflow - - @property - def sdk_workflow(self): - """ - :rtype: flytekit.common.workflow.SdkWorkflow - """ - return self._sdk_workflow - - @_abc.abstractmethod - def register(self, identifier): - """ - Registers the workflow - :param flytekit.models.core.identifier.Identifier identifier: - """ - pass - - -class BaseWorkflowExecution(object, metaclass=_common_models.FlyteABCMeta): - """ - This class must be implemented for any engine to track and interact with the executions of workflows. - """ - - def __init__(self, sdk_wf_exec): - """ - :param flytekit.common.workflow_execution.SdkWorkflowExecution sdk_wf_exec: - """ - self._sdk_wf_exec = sdk_wf_exec - - @property - def sdk_workflow_execution(self): - """ - :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution - """ - return self._sdk_wf_exec - - @_abc.abstractmethod - def get_node_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - pass - - @_abc.abstractmethod - def sync(self): - """ - :rtype: None - """ - pass - - @_abc.abstractmethod - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def terminate(self, cause): - """ - :param Text cause: - """ - pass - - -class BaseNodeExecution(object, metaclass=_common_models.FlyteABCMeta): - def __init__(self, node_execution): - """ - :param flytekit.common.nodes.SdkNodeExecution node_execution: - """ - self._sdk_node_execution = node_execution - - @property - def sdk_node_execution(self): - """ - :rtype: flytekit.common.nodes.SdkNodeExecution - """ - return self._sdk_node_execution - - @_abc.abstractmethod - def get_task_executions(self): - """ - :rtype: list[flytekit.common.tasks.executions.SdkTaskExecution] - """ - pass - - @_abc.abstractmethod - def get_subworkflow_executions(self): - """ - :rtype: list[flytekit.common.workflow_execution.SdkWorkflowExecution] - """ - pass - - @_abc.abstractmethod - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def sync(self): - """ - :rtype: None - """ - pass - - -class BaseTaskExecution(object, metaclass=_common_models.FlyteABCMeta): - def __init__(self, task_exec): - """ - :param flytekit.common.tasks.executions.SdkTaskExecution task_exec: - """ - self._sdk_task_execution = task_exec - - @property - def sdk_task_execution(self): - """ - :rtype: flytekit.common.tasks.executions.SdkTaskExecution - """ - return self._sdk_task_execution - - @_abc.abstractmethod - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - pass - - @_abc.abstractmethod - def sync(self): - """ - :rtype: None - """ - pass - - @_abc.abstractmethod - def 
get_child_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - pass - - -class BaseLaunchPlanLauncher(object, metaclass=_common_models.FlyteABCMeta): - def __init__(self, sdk_launch_plan): - """ - :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: - """ - self._sdk_launch_plan = sdk_launch_plan - - @property - def sdk_launch_plan(self): - """ - :rtype: flytekit.common.launch_plan.SdkLaunchPlan - """ - return self._sdk_launch_plan - - @_abc.abstractmethod - def register(self, identifier): - """ - Registers the launch plan - :param flytekit.models.core.identifier.Identifier identifier: - """ - pass - - @_abc.abstractmethod - def launch( - self, - project, - domain, - name, - inputs, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Registers the launch plan and returns the identifier. - :param Text project: - :param Text domain: - :param Text name: - :param flytekit.models.literals.LiteralMap inputs: The inputs to pass - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the - notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :rtype: flytekit.models.execution.Execution - """ - pass - - @_abc.abstractmethod - def update(self, identifier, state): - """ - :param flytekit.models.core.identifier.Identifier identifier: ID for launch plan to update - :param int state: Enum value from flytekit.models.launch_plan.LaunchPlanState - """ - pass - - -class BaseTaskExecutor(object, metaclass=_common_models.FlyteABCMeta): - def __init__(self, sdk_task): - """ - :param flytekit.common.tasks.task.SdkTask sdk_task: - """ - self._sdk_task = sdk_task - - @property - def sdk_task(self): - """ - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - return self._sdk_task - - @_abc.abstractmethod - def execute(self, inputs, context=None): - """ - :param flytekit.models.literals.LiteralMap inputs: Inputs to pass to the workflow. - """ - pass - - @_abc.abstractmethod - def register(self, identifier): - """ - Registers the task - :param flytekit.models.core.identifier.Identifier identifier: - """ - pass - - @_abc.abstractmethod - def launch( - self, - project, - domain, - name=None, - inputs=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Executes the task as a single task execution and returns the identifier. - :param Text project: - :param Text domain: - :param Text name: - :param flytekit.models.literals.LiteralMap inputs: The inputs to pass - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the - notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: - :rtype: flytekit.models.execution.Execution - """ - pass - - -class BaseExecutionEngineFactory(object, metaclass=_common_models.FlyteABCMeta): - """ - This object should be implemented to satisfy the basic engine interface. 
- """ - - @_abc.abstractmethod - def get_task(self, sdk_task): - """ - :param flytekit.common.tasks.task.SdkTask sdk_task: - :rtype: BaseTaskExecutor - """ - pass - - @_abc.abstractmethod - def get_launch_plan(self, sdk_launch_plan): - """ - :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: - :rtype: BaseLaunchPlanLauncher - """ - pass - - @_abc.abstractmethod - def get_task_execution(self, task_exec): - """ - :param flytekit.common.tasks.executions.SdkTaskExecution task_exec: - :rtype: BaseTaskExecution - """ - pass - - @_abc.abstractmethod - def get_node_execution(self, node_exec): - """ - :param flytekit.common.nodes.SdkNodeExecution node_exec: - :rtype: BaseNodeExecution - """ - pass - - @_abc.abstractmethod - def get_workflow_execution(self, wf_exec): - """ - :param flytekit.common.workflow_execution.SdkWorkflowExecution wf_exec: - :rtype: BaseWorkflowExecution - """ - pass - - @_abc.abstractmethod - def fetch_workflow_execution(self, wf_exec_id): - """ - :param flytekit.models.core.identifier.WorkflowExecutionIdentifier wf_exec_id: - :rtype: flytekit.models.execution.Execution - """ - pass - - @_abc.abstractmethod - def fetch_task(self, task_id): - """ - :param flytekit.models.core.identifier.Identifier task_id: This identifier should have a resource type of kind - Task. - :rtype: flytekit.models.task.Task - """ - pass - - @_abc.abstractmethod - def fetch_latest_task(self, named_task): - """ - Fetches the latest task - :param flytekit.models.common.NamedEntityIdentifier named_task: NamedEntityIdentifier to fetch - :rtype: flytekit.models.task.Task - """ - pass - - -class EngineContext(object): - def __init__( - self, - execution_date, - tmp_dir, - stats, - execution_id, - logging, - raw_output_data_prefix=None, - ): - self._stats = stats - self._execution_date = execution_date - self._working_directory = tmp_dir - self._execution_id = execution_id - self._logging = logging - self._raw_output_data_prefix = raw_output_data_prefix - - @property - def stats(self): - """ - :rtype: flytekit.interfaces.stats.taggable.TaggableStats - """ - return self._stats - - @property - def logging(self): - """ - :rtype: TODO - """ - return self._logging - - @property - def working_directory(self): - """ - :rtype: flytekit.common.utils.AutoDeletingTempDir - """ - return self._working_directory - - @property - def execution_date(self): - """ - :rtype: datetime.datetime - """ - return self._execution_date - - @property - def execution_id(self): - """ - :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier - """ - return self._execution_id - - @property - def raw_output_data_prefix(self) -> str: - return self._raw_output_data_prefix diff --git a/flytekit/engines/flyte/__init__.py b/flytekit/engines/flyte/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py deleted file mode 100644 index 7af35e20b9..0000000000 --- a/flytekit/engines/flyte/engine.py +++ /dev/null @@ -1,727 +0,0 @@ -import logging as _logging -import os as _os -import traceback as _traceback -from datetime import datetime as _datetime - -import six as _six -from deprecated import deprecated as _deprecated -from flyteidl.core import literals_pb2 as _literals_pb2 - -import flytekit -from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions -from flytekit.clients.helpers import iterate_task_executions as 
_iterate_task_executions -from flytekit.common import constants as _constants -from flytekit.common import utils as _common_utils -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.configuration import auth as _auth_config -from flytekit.configuration import internal as _internal_config -from flytekit.configuration import platform as _platform_config -from flytekit.configuration import sdk as _sdk_config -from flytekit.engines import common as _common_engine -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.interfaces.stats.taggable import get_stats as _get_stats -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models import literals as _literals -from flytekit.models import task as _task_models -from flytekit.models.admin import common as _common -from flytekit.models.admin import workflow as _workflow_model -from flytekit.models.core import errors as _error_models -from flytekit.models.core import identifier as _identifier - - -class _FlyteClientManager(object): - _CLIENT = None - - def __init__(self, *args, **kwargs): - # TODO: React to changing configs. For now this is frozen for the lifetime of the process, which covers most - # TODO: use cases. - if type(self)._CLIENT is None: - c = _SynchronousFlyteClient(*args, **kwargs) - type(self)._CLIENT = c - - @property - def client(self): - """ - :rtype: flytekit.clients.friendly.SynchronousFlyteClient - """ - return type(self)._CLIENT - - -# This is a simple helper function that ties the client together with the configuration construct. -# This will be refactored away when we move to a heavier context object. 
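The ``get_client`` helper deleted just below relied on ``_FlyteClientManager``'s class-level cache: the first call builds a ``SynchronousFlyteClient`` from platform configuration, and every subsequent call returns that same instance. A sketch of the removed behavior:

.. code-block:: python

    # Both calls yield the same SynchronousFlyteClient, because
    # _FlyteClientManager stores the first client it constructs on the class.
    c1 = get_client()
    c2 = get_client()
    assert c1 is c2
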
-def get_client() -> _SynchronousFlyteClient: - return _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - - -class FlyteEngineFactory(_common_engine.BaseExecutionEngineFactory): - def get_workflow(self, sdk_workflow): - """ - :param flytekit.common.workflow.SdkWorkflow sdk_workflow: - :rtype: FlyteWorkflow - """ - return FlyteWorkflow(sdk_workflow) - - def get_task(self, sdk_task): - """ - :param flytekit.common.tasks.task.SdkTask sdk_task: - :rtype: FlyteTask - """ - return FlyteTask(sdk_task) - - def get_launch_plan(self, sdk_launch_plan): - """ - :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: - :rtype: FlyteLaunchPlan - """ - return FlyteLaunchPlan(sdk_launch_plan) - - def get_task_execution(self, task_exec): - """ - :param flytekit.common.tasks.executions.SdkTaskExecution task_exec: - :rtype: FlyteTaskExecution - """ - return FlyteTaskExecution(task_exec) - - def get_node_execution(self, node_exec): - """ - :param flytekit.common.nodes.SdkNodeExecution node_exec: - :rtype: FlyteNodeExecution - """ - return FlyteNodeExecution(node_exec) - - def get_workflow_execution(self, wf_exec): - """ - :param flytekit.common.workflow_execution.SdkWorkflowExecution wf_exec: - :rtype: FlyteWorkflowExecution - """ - return FlyteWorkflowExecution(wf_exec) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def fetch_workflow_execution(self, wf_exec_id): - """ - :param flytekit.models.core.identifier.WorkflowExecutionIdentifier wf_exec_id: - :rtype: flytekit.models.execution.Execution - """ - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.get_execution(wf_exec_id) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def fetch_task(self, task_id): - """ - Queries Admin for an existing Admin task - :param flytekit.models.core.identifier.Identifier task_id: - :rtype: flytekit.models.task.Task - """ - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.get_task(task_id) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def fetch_latest_task(self, named_task): - """ - Fetches the latest task - :param flytekit.models.common.NamedEntityIdentifier named_task: NamedEntityIdentifier to fetch - :rtype: flytekit.models.task.Task - """ - task_list, _ = _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.list_tasks_paginated( - named_task, - limit=1, - sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), - ) - return task_list[0] if task_list else None - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def fetch_launch_plan(self, launch_plan_id): - """ - :param flytekit.models.core.identifier.Identifier launch_plan_id: This identifier should have a resource - type of kind LaunchPlan. 
- :rtype: flytekit.models.launch_plan.LaunchPlan - """ - if launch_plan_id.version: - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.get_launch_plan(launch_plan_id) - else: - named_entity_id = _common_models.NamedEntityIdentifier( - launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name - ) - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.get_active_launch_plan(named_entity_id) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def fetch_workflow(self, workflow_id): - """ - :param flytekit.models.core.identifier.Identifier workflow_id: This identifier should have a resource - type of kind LaunchPlan. - :rtype: flytekit.models.admin.workflow.Workflow - """ - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.get_workflow(workflow_id) - - -class FlyteLaunchPlan(_common_engine.BaseLaunchPlanLauncher): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def register(self, identifier): - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - try: - client.create_launch_plan(identifier, self.sdk_launch_plan) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - - @_deprecated(reason="Use launch instead", version="0.9.0") - def execute( - self, - project, - domain, - name, - inputs, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Deprecated. Use launch instead. - """ - return self.launch( - project, - domain, - name, - inputs, - notification_overrides, - label_overrides, - annotation_overrides, - ) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def launch( - self, - project, - domain, - name, - inputs, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - ): - """ - Creates a workflow execution using parameters specified in the launch plan. - :param Text project: - :param Text domain: - :param Text name: - :param flytekit.models.literals.LiteralMap inputs: - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the - notifications. 
- :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :rtype: flytekit.models.execution.Execution - """ - disable_all = notification_overrides == [] - if disable_all: - notification_overrides = None - else: - notification_overrides = _execution_models.NotificationList(notification_overrides or []) - disable_all = None - - try: - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - exec_id = client.create_execution( - project, - domain, - name, - _execution_models.ExecutionSpec( - self.sdk_launch_plan.id, - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - "sdk", # TODO: get principle - 0, # TODO: Detect nesting - ), - notifications=notification_overrides, - disable_all=disable_all, - labels=label_overrides, - annotations=annotation_overrides, - ), - inputs, - ) - except _user_exceptions.FlyteEntityAlreadyExistsException: - exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) - return client.get_execution(exec_id) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def update(self, identifier, state): - """ - :param flytekit.models.core.identifier.Identifier identifier: Identifier for launch plan to update - :param int state: Enum value from flytekit.models.launch_plan.LaunchPlanState - """ - return _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.update_launch_plan(identifier, state) - - -class FlyteWorkflow(_common_engine.BaseWorkflowExecutor): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def register(self, identifier): - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - try: - sub_workflows = self.sdk_workflow.get_sub_workflows() - return client.create_workflow( - identifier, - _workflow_model.WorkflowSpec( - self.sdk_workflow, - sub_workflows, - ), - ) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - - -class FlyteTask(_common_engine.BaseTaskExecutor): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def register(self, identifier): - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - try: - client.create_task(identifier, _task_models.TaskSpec(self.sdk_task)) - except _user_exceptions.FlyteEntityAlreadyExistsException: - pass - - def execute(self, inputs, context=None): - """ - Just execute the task and write the outputs to where they belong - :param flytekit.models.literals.LiteralMap inputs: - :param dict[Text, Text] context: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - """ - with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir: - with _common_utils.AutoDeletingTempDir("task_dir") as task_dir: - with _data_proxy.LocalWorkingDirectoryContext(task_dir): - raw_output_data_prefix = context.get("raw_output_data_prefix", None) - with _data_proxy.RemoteDataContext(raw_output_data_prefix_override=raw_output_data_prefix): - output_file_dict = dict() - - # This sets the logging level for user code and is the only place an sdk setting gets - # used at runtime. Optionally, Propeller can set an internal config setting which - # takes precedence. 
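The comment above describes a two-level precedence that the next deleted line implements as a plain `or` chain over the two config getters. A caveat with that idiom: `or` falls through on any falsy value, so a configured level of `0` (`logging.NOTSET`) is indistinguishable from "unset". A small sketch of the idiom next to a stricter `None`-aware variant (the argument values are illustrative, not real flytekit config calls):

    import logging

    def resolve_level(internal_level, sdk_level, default=logging.WARNING):
        # Falsy fallthrough, as in the deleted code below:
        loose = internal_level or sdk_level or default
        # Explicit None check that preserves logging.NOTSET (0):
        strict = next((lv for lv in (internal_level, sdk_level) if lv is not None), default)
        return loose, strict

    print(resolve_level(0, logging.DEBUG))  # (10, 0): the two idioms disagree on NOTSET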
- log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get() - _logging.getLogger().setLevel(log_level) - - try: - output_file_dict = self.sdk_task.execute( - _common_engine.EngineContext( - execution_id=_identifier.WorkflowExecutionIdentifier( - project=_internal_config.EXECUTION_PROJECT.get(), - domain=_internal_config.EXECUTION_DOMAIN.get(), - name=_internal_config.EXECUTION_NAME.get(), - ), - execution_date=_datetime.utcnow(), - stats=_get_stats( - # Stats metric path will be: - # registration_project.registration_domain.app.module.task_name.user_stats - # and it will be tagged with execution-level values for project/domain/wf/lp - "{}.{}.{}.user_stats".format( - _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(), - _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(), - _internal_config.TASK_NAME.get() or _internal_config.NAME.get(), - ), - tags={ - "exec_project": _internal_config.EXECUTION_PROJECT.get(), - "exec_domain": _internal_config.EXECUTION_DOMAIN.get(), - "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(), - "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(), - "api_version": flytekit.__version__, - }, - ), - logging=_logging, - tmp_dir=task_dir, - raw_output_data_prefix=context["raw_output_data_prefix"] - if "raw_output_data_prefix" in context - else None, - ), - inputs, - ) - except _exception_scopes.FlyteScopedException as e: - _logging.error("!!! Begin Error Captured by Flyte !!!") - output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError(e.error_code, e.verbose_message, e.kind, 0) - ) - _logging.error(e.verbose_message) - _logging.error("!!! End Error Captured by Flyte !!!") - except Exception: - _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!") - exc_str = _traceback.format_exc() - output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError( - "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE, 0 - ) - ) - _logging.error(exc_str) - _logging.error("!!! End Error Captured by Flyte !!!") - finally: - for k, v in _six.iteritems(output_file_dict): - _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k)) - _data_proxy.Data.put_data( - temp_dir.name, - context["output_prefix"], - is_multipart=True, - ) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def launch( - self, - project, - domain, - name=None, - inputs=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - """ - Executes the task as a single task execution and returns the identifier. - :param Text project: - :param Text domain: - :param Text name: - :param flytekit.models.literals.LiteralMap inputs: The inputs to pass - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the - notifications. 
- :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: - :rtype: flytekit.models.execution.Execution - """ - disable_all = notification_overrides == [] - if disable_all: - notification_overrides = None - else: - notification_overrides = _execution_models.NotificationList(notification_overrides or []) - disable_all = None - - if not auth_role: - assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() - kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() - - if not (assumable_iam_role or kubernetes_service_account): - _logging.warning( - "Using deprecated `role` from config. " - "Please update your config to use `assumable_iam_role` instead" - ) - assumable_iam_role = _sdk_config.ROLE.get() - auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account, - ) - - try: - # TODO(katrogan): Add handling to register the underlying task if it's not already. - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - exec_id = client.create_execution( - project, - domain, - name, - _execution_models.ExecutionSpec( - self.sdk_task.id, - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - "sdk", # TODO: get principle - 0, # TODO: Detect nesting - ), - notifications=notification_overrides, - disable_all=disable_all, - labels=label_overrides, - annotations=annotation_overrides, - auth_role=auth_role, - ), - inputs, - ) - except _user_exceptions.FlyteEntityAlreadyExistsException: - exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name) - return client.get_execution(exec_id) - - -class FlyteWorkflowExecution(_common_engine.BaseWorkflowExecution): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_node_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - return { - v.id.node_id: v for v in _iterate_node_executions(client, self.sdk_workflow_execution.id, filters=filters) - } - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def sync(self): - """ - :rtype: None - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - self.sdk_workflow_execution._closure = client.get_execution(self.sdk_workflow_execution.id).closure - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_execution_data(self.sdk_workflow_execution.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. 
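Every `get_inputs`/`get_outputs` method in the classes being deleted here follows the retrieval protocol the comment above describes: use the literals when they come back inline, otherwise download the offloaded protobuf behind the returned URL blob, and fall back to an empty literal map. A hedged sketch of that shared shape, with `download` and `parse` injected as placeholder callables rather than real flytekit APIs:

    def load_literal_map(data, download, parse):
        # Small payloads arrive inline on the response object.
        if data.full_inputs.literals:
            return data.full_inputs
        # Large payloads are offloaded; only a URL and a byte count come back.
        if data.inputs.bytes > 0:
            local_pb_path = download(data.inputs.url)
            return parse(local_pb_path)
        # Nothing was recorded at all.
        return {}

The deleted engine never factored this branch out; the same block repeats, nearly verbatim, six times across the three execution classes below.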
- if bool(execution_data.full_inputs.literals): - return execution_data.full_inputs - - if execution_data.inputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_execution_data(self.sdk_workflow_execution.id) - - # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_outputs.literals): - return execution_data.full_outputs - - if execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def terminate(self, cause): - """ - :param Text cause: - """ - _FlyteClientManager( - _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() - ).client.terminate_execution(self.sdk_workflow_execution.id, cause) - - -class FlyteNodeExecution(_common_engine.BaseNodeExecution): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_task_executions(self): - """ - :rtype: list[flytekit.common.tasks.executions.SdkTaskExecution] - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - return list(_iterate_task_executions(client, self.sdk_node_execution.id)) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_subworkflow_executions(self): - """ - :rtype: list[flytekit.common.workflow_execution.SdkWorkflowExecution] - """ - raise NotImplementedError("Cannot retrieve sub-workflow information from a node execution yet.") - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_node_execution_data(self.sdk_node_execution.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. 
- if bool(execution_data.full_inputs.literals): - return execution_data.full_inputs - - if execution_data.inputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_node_execution_data(self.sdk_node_execution.id) - - # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_outputs.literals): - return execution_data.full_outputs - - if execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def sync(self): - """ - :rtype: None - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - self.sdk_node_execution._closure = client.get_node_execution(self.sdk_node_execution.id).closure - - -class FlyteTaskExecution(_common_engine.BaseTaskExecution): - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_inputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_task_execution_data(self.sdk_task_execution.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - if bool(execution_data.full_inputs.literals): - return execution_data.full_inputs - - if execution_data.inputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_outputs(self): - """ - :rtype: flytekit.models.literals.LiteralMap - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - execution_data = client.get_task_execution_data(self.sdk_task_execution.id) - - # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. 
- if bool(execution_data.full_outputs.literals): - return execution_data.full_outputs - - if execution_data.outputs.bytes > 0: - with _common_utils.AutoDeletingTempDir() as t: - tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) - return _literals.LiteralMap.from_flyte_idl( - _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) - ) - return _literals.LiteralMap({}) - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def sync(self): - """ - :rtype: None - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - self.sdk_task_execution._closure = client.get_task_execution(self.sdk_task_execution.id).closure - - @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", - version="0.13.0", - ) - def get_child_executions(self, filters=None): - """ - :param list[flytekit.models.filters.Filter] filters: - :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] - """ - client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - return { - v.id.node_id: v - for v in _iterate_node_executions( - client, - task_execution_identifier=self.sdk_task_execution.id, - filters=filters, - ) - } diff --git a/flytekit/engines/loader.py b/flytekit/engines/loader.py deleted file mode 100644 index 336d5b9de6..0000000000 --- a/flytekit/engines/loader.py +++ /dev/null @@ -1,44 +0,0 @@ -import importlib as _importlib - -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.configuration import sdk as _sdk_config - -_ENGINE_NAME_TO_MODULES_CACHE = { - "flyte": ("flytekit.engines.flyte.engine", "FlyteEngineFactory", None), - "unit": ("flytekit.engines.unit.engine", "UnitTestEngineFactory", None), - # 'local': ('flytekit.engines.local.engine', 'EngineObjectFactory', None) -} - - -def get_engine(engine_name=None): - """ - :param Text engine_name: - :rtype: flytekit.engines.common.BaseExecutionEngineFactory - """ - engine_name = engine_name or _sdk_config.EXECUTION_ENGINE.get() - - # TODO: Allow users to plug-in their own engine code via a config - if engine_name not in _ENGINE_NAME_TO_MODULES_CACHE: - raise _user_exceptions.FlyteValueException( - engine_name, - "Could not load an engine with the identifier '{}'. 
Known engines are: {}".format( - engine_name, list(_ENGINE_NAME_TO_MODULES_CACHE.keys()) - ), - ) - - module_path, attr, engine_impl = _ENGINE_NAME_TO_MODULES_CACHE[engine_name] - if engine_impl is None: - module = _exception_scopes.user_entry_point(_importlib.import_module)(module_path) - - if not hasattr(module, attr): - raise _user_exceptions.FlyteValueException( - module, - "Failed to load the engine because the attribute named '{}' could not be found" - "in the module '{}'.".format(attr, module_path), - ) - - engine_impl = getattr(module, attr)() - _ENGINE_NAME_TO_MODULES_CACHE[engine_name] = (module_path, attr, engine_impl) - - return engine_impl diff --git a/flytekit/engines/unit/__init__.py b/flytekit/engines/unit/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py deleted file mode 100644 index 2e3c4e187b..0000000000 --- a/flytekit/engines/unit/engine.py +++ /dev/null @@ -1,329 +0,0 @@ -import logging as _logging -import os as _os -from datetime import datetime as _datetime - -import six as _six -from flyteidl.plugins import qubole_pb2 as _qubole_pb2 -from google.protobuf.json_format import ParseDict as _ParseDict -from six import moves as _six_moves - -from flytekit.common import constants as _sdk_constants -from flytekit.common import utils as _common_utils -from flytekit.common.exceptions import system as _system_exception -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration -from flytekit.engines import common as _common_engine -from flytekit.engines.unit.mock_stats import MockStats -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import array_job as _array_job -from flytekit.models import literals as _literals -from flytekit.models import qubole as _qubole_models -from flytekit.models.core.identifier import WorkflowExecutionIdentifier - - -class UnitTestEngineFactory(_common_engine.BaseExecutionEngineFactory): - def get_task(self, sdk_task): - """ - :param flytekit.common.tasks.task.SdkTask sdk_task: - :rtype: UnitTestEngineTask - """ - if sdk_task.type in { - _sdk_constants.SdkTaskType.PYTHON_TASK, - _sdk_constants.SdkTaskType.SPARK_TASK, - _sdk_constants.SdkTaskType.SENSOR_TASK, - }: - return ReturnOutputsTask(sdk_task) - elif sdk_task.type in { - _sdk_constants.SdkTaskType.DYNAMIC_TASK, - }: - return DynamicTask(sdk_task) - elif sdk_task.type in { - _sdk_constants.SdkTaskType.BATCH_HIVE_TASK, - }: - return HiveTask(sdk_task) - else: - raise _user_exceptions.FlyteAssertion( - "Unit tests are not currently supported for tasks of type: {}".format(sdk_task.type) - ) - - def get_workflow(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing of workflows is not currently supported") - - def get_launch_plan(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing of launch plans is not currently supported") - - def get_task_execution(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not return execution handles.") - - def get_node_execution(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not return execution handles.") - - def get_workflow_execution(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not return execution handles.") - - def fetch_workflow_execution(self, _): - raise _user_exceptions.FlyteAssertion("Unit 
testing does not fetch execution handles.") - - def fetch_task(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not fetch real tasks.") - - def fetch_latest_task(self, named_task): - raise _user_exceptions.FlyteAssertion("Unit testing does not fetch the real latest task.") - - def fetch_launch_plan(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not fetch real launch plans.") - - def fetch_workflow(self, _): - raise _user_exceptions.FlyteAssertion("Unit testing does not fetch real workflows.") - - -class UnitTestEngineTask(_common_engine.BaseTaskExecutor): - def execute(self, inputs, context=None): - """ - Just execute the function and return the outputs as a user-readable dictionary. - :param flytekit.models.literals.LiteralMap inputs: - :param context: - :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] - """ - with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(__file__), "unit.config"), - internal_overrides={"image": "unit_image"}, - ): - with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory: - with _data_proxy.LocalWorkingDirectoryContext(working_directory): - return self._transform_for_user_output(self._execute_user_code(inputs)) - - def _execute_user_code(self, inputs): - """ - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] - """ - with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory: - return self.sdk_task.execute( - _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), - execution_date=_datetime.utcnow(), - stats=MockStats(), - logging=_logging, # TODO: A mock logging object that we can read later. - tmp_dir=user_working_directory, - ), - inputs, - ) - - def _transform_for_user_output(self, outputs): - """ - Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this - task's unit test. - :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: - :rtype: T - """ - return outputs - - def register(self, identifier, version): - raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.") - - def launch( - self, - project, - domain, - name=None, - inputs=None, - notification_overrides=None, - label_overrides=None, - annotation_overrides=None, - auth_role=None, - ): - raise _user_exceptions.FlyteAssertion("You cannot launch unit test tasks.") - - -class ReturnOutputsTask(UnitTestEngineTask): - def _transform_for_user_output(self, outputs): - """ - Just return the outputs as a user-readable dictionary. - :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: - :rtype: T - """ - literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME] - return { - name: _type_helpers.get_sdk_value_from_literal( - literal_map.literals[name], - sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), - ).to_python_std() - for name, variable in _six.iteritems(self.sdk_task.interface.outputs) - } - - -class DynamicTask(ReturnOutputsTask): - def __init__(self, *args, **kwargs): - self._has_workflow_node = False - super(DynamicTask, self).__init__(*args, **kwargs) - - def _transform_for_user_output(self, outputs): - if self.has_workflow_node: - # If a workflow node has been detected, then we skip any transformation - # This is to support the early termination behavior of the unit test engine when it comes to dynamic tasks - # that produce launch plan or subworkflow nodes. 
- # See the warning message in the code below for additional information - return outputs - return super(DynamicTask, self)._transform_for_user_output(outputs) - - def _execute_user_code(self, inputs): - """ - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] - """ - results = super(DynamicTask, self)._execute_user_code(inputs) - if _sdk_constants.FUTURES_FILE_NAME in results: - futures = results[_sdk_constants.FUTURES_FILE_NAME] - sub_task_outputs = {} - tasks_map = {task.id: task for task in futures.tasks} - - for future_node in futures.nodes: - if future_node.workflow_node is not None: - # TODO: implement proper unit testing for launchplan and subworkflow nodes somehow - _logging.warning( - "A workflow node has been detected in the output of the dynamic task. The " - "Flytekit unit test engine is incomplete for dynamic tasks that return launch " - "plans or subworkflows. The generated dynamic job spec will be returned but " - "they will not be run." - ) - # For now, just return the output of the parent task - self._has_workflow_node = True - return results - task = tasks_map[future_node.task_node.reference_id] - if task.type == _sdk_constants.SdkTaskType.CONTAINER_ARRAY_TASK: - sub_task_output = DynamicTask.execute_array_task(future_node.id, task, results) - elif task.type == _sdk_constants.SdkTaskType.SPARK_TASK: - # This is required because `_transform_for_user_output` function about is invoked which - # checks for outputs - self._has_workflow_node = True - return results - elif task.type == _sdk_constants.SdkTaskType.HIVE_JOB: - # TODO: futures.outputs should have the Schema instances. - # After schema is implemented, fill out random data into the random locations - # then check output in test function - # Even though we recommend people use typed schemas, they might not always do so... - # in which case it'll be impossible to predict the actual schema, we should support a - # way for unit test authors to provide fake data regardless - sub_task_output = None - else: - inputs_path = _os.path.join(future_node.id, _sdk_constants.INPUT_FILE_NAME) - if inputs_path not in results: - raise _system_exception.FlyteSystemAssertion( - "dynamic task hasn't generated expected inputs document [{}] found {}".format( - future_node.id, list(results.keys()) - ) - ) - sub_task_output = UnitTestEngineFactory().get_task(task).execute(results[inputs_path]) - sub_task_outputs[future_node.id] = sub_task_output - - results[_sdk_constants.OUTPUT_FILE_NAME] = _literals.LiteralMap( - literals={ - binding.var: DynamicTask.fulfil_bindings(binding.binding, sub_task_outputs) - for binding in futures.outputs - } - ) - return results - - @property - def has_workflow_node(self): - """ - :rtype: bool - """ - return self._has_workflow_node - - @staticmethod - def execute_array_task(root_input_path, task, array_inputs): - array_job = _array_job.ArrayJob.from_dict(task.custom) - outputs = {} - for job_index in _six_moves.range(0, array_job.size): - inputs_path = _os.path.join( - root_input_path, - _six.text_type(job_index), - _sdk_constants.INPUT_FILE_NAME, - ) - if inputs_path not in array_inputs: - raise _system_exception.FlyteSystemAssertion( - "dynamic task hasn't generated expected inputs document [{}].".format(inputs_path) - ) - - input_proto = array_inputs[inputs_path] - # All outputs generated by the same array job will have the same key in sub_task_outputs, - # they will, however, differ in the var names; they will be on the format []. - # e.g. 
[1].out1 - for key, val in _six.iteritems( - ReturnOutputsTask( - task.assign_type_and_return(_sdk_constants.SdkTaskType.PYTHON_TASK) # TODO: This is weird - ).execute(input_proto) - ): - outputs["[{}].{}".format(job_index, key)] = val - return outputs - - @staticmethod - def fulfil_bindings(binding_data, fulfilled_promises): - """ - Substitutes promise values in binding_data with model Literal values built from python std values in - fulfilled_promises - - :param _interface.BindingData binding_data: - :param dict[Text,T] fulfilled_promises: - :rtype: - """ - if binding_data.scalar: - return _literals.Literal(scalar=binding_data.scalar) - elif binding_data.collection: - return _literals.Literal( - collection=_literals.LiteralCollection( - [ - DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) - for sub_binding_data in binding_data.collection.bindings - ] - ) - ) - elif binding_data.promise: - if binding_data.promise.node_id not in fulfilled_promises: - raise _system_exception.FlyteSystemAssertion( - "Expecting output of node [{}] but that hasn't been produced.".format(binding_data.promise.node_id) - ) - node_output = fulfilled_promises[binding_data.promise.node_id] - if binding_data.promise.var not in node_output: - raise _system_exception.FlyteSystemAssertion( - "Expecting output [{}] of node [{}] but that hasn't been produced.".format( - binding_data.promise.var, binding_data.promise.node_id - ) - ) - - return binding_data.promise.sdk_type.from_python_std(node_output[binding_data.promise.var]) - elif binding_data.map: - return _literals.Literal( - map=_literals.LiteralMap( - { - k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) - for k, sub_binding_data in _six.iteritems(binding_data.map.bindings) - } - ) - ) - - -class HiveTask(DynamicTask): - def _transform_for_user_output(self, outputs): - """ - Just execute the function and return the list of Hive queries returned. 
- :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: - :rtype: list[Text] - """ - futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME) - if futures: - queries = [] - task_ids_to_defs = { - t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl( - _ParseDict(t.custom, _qubole_pb2.QuboleHiveJob()) - ) - for t in futures.tasks - } - for node in futures.nodes: - queries.append(task_ids_to_defs[node.task_node.reference_id.name].query.query) - return queries - else: - return [] diff --git a/flytekit/engines/unit/unit.config b/flytekit/engines/unit/unit.config deleted file mode 100644 index 8e2c00a83e..0000000000 --- a/flytekit/engines/unit/unit.config +++ /dev/null @@ -1,14 +0,0 @@ -[sdk] - -workflow_packages=this.module,that.module - -[auth] -assumable_iam_role=unit_test_role - -[container] - -image=some_docker_repo:some_image_name:tag - -[platform] - -url=unittest diff --git a/flytekit/common/__init__.py b/flytekit/exceptions/__init__.py similarity index 100% rename from flytekit/common/__init__.py rename to flytekit/exceptions/__init__.py diff --git a/flytekit/common/exceptions/base.py b/flytekit/exceptions/base.py similarity index 100% rename from flytekit/common/exceptions/base.py rename to flytekit/exceptions/base.py diff --git a/flytekit/common/exceptions/scopes.py b/flytekit/exceptions/scopes.py similarity index 97% rename from flytekit/common/exceptions/scopes.py rename to flytekit/exceptions/scopes.py index a32d9dbcc6..60a4afa97e 100644 --- a/flytekit/common/exceptions/scopes.py +++ b/flytekit/exceptions/scopes.py @@ -3,9 +3,9 @@ from wrapt import decorator as _decorator -from flytekit.common.exceptions import base as _base_exceptions -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import base as _base_exceptions +from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.models.core import errors as _error_model diff --git a/flytekit/common/exceptions/system.py b/flytekit/exceptions/system.py similarity index 95% rename from flytekit/common/exceptions/system.py rename to flytekit/exceptions/system.py index 5802279590..63c43e8879 100644 --- a/flytekit/common/exceptions/system.py +++ b/flytekit/exceptions/system.py @@ -1,4 +1,4 @@ -from flytekit.common.exceptions import base as _base_exceptions +from flytekit.exceptions import base as _base_exceptions class FlyteSystemException(_base_exceptions.FlyteRecoverableException): diff --git a/flytekit/common/exceptions/user.py b/flytekit/exceptions/user.py similarity index 94% rename from flytekit/common/exceptions/user.py rename to flytekit/exceptions/user.py index acb5dd7997..a93510cae1 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -1,5 +1,5 @@ -from flytekit.common.exceptions.base import FlyteException as _FlyteException -from flytekit.common.exceptions.base import FlyteRecoverableException as _Recoverable +from flytekit.exceptions.base import FlyteException as _FlyteException +from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable class FlyteUserException(_FlyteException): diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index 2c22fd5bd9..d420310fa2 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -33,12 +33,11 @@ DataPersistencePlugins """ -from flytekit.common.translator import get_serializable from flytekit.core import 
context_manager from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.context_manager import ExecutionState, Image, ImageConfig, SerializationSettings +from flytekit.core.context_manager import ExecutionState, Image, ImageConfig, SecretsManager, SerializationSettings from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.interface import Interface from flytekit.core.promise import Promise @@ -46,3 +45,4 @@ from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.task import TaskPlugins from flytekit.core.type_engine import DictTransformer, T, TypeEngine, TypeTransformer +from flytekit.tools.translator import get_serializable diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py index bb4ec31487..7e7711d64a 100644 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ b/flytekit/extras/persistence/gcs_gsutil.py @@ -2,9 +2,9 @@ import typing from shutil import which as shell_which -from flytekit.common.exceptions.user import FlyteUserException from flytekit.configuration import gcp from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.exceptions.user import FlyteUserException from flytekit.tools import subprocess diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py index d9fa4674d7..30fa8d0f65 100644 --- a/flytekit/extras/persistence/http.py +++ b/flytekit/extras/persistence/http.py @@ -3,8 +3,8 @@ import requests -from flytekit.common.exceptions import user from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.exceptions import user from flytekit.loggers import logger diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py index ddf26ec4d5..3b24fef94b 100644 --- a/flytekit/extras/persistence/s3_awscli.py +++ b/flytekit/extras/persistence/s3_awscli.py @@ -7,9 +7,9 @@ from shutil import which as shell_which from typing import Dict, List, Optional -from flytekit.common.exceptions.user import FlyteUserException from flytekit.configuration import aws from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.exceptions.user import FlyteUserException from flytekit.tools import subprocess diff --git a/flytekit/common/core/identifier.py b/flytekit/interfaces/cli_identifiers.py similarity index 89% rename from flytekit/common/core/identifier.py rename to flytekit/interfaces/cli_identifiers.py index c7b12a5190..18cfad424b 100644 --- a/flytekit/common/core/identifier.py +++ b/flytekit/interfaces/cli_identifiers.py @@ -1,18 +1,15 @@ -import six as _six - -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.models.core import identifier as _core_identifier -class Identifier(_core_identifier.Identifier, metaclass=_sdk_bases.ExtendedSdkType): +class Identifier(_core_identifier.Identifier): _STRING_TO_TYPE_MAP = { "lp": _core_identifier.ResourceType.LAUNCH_PLAN, "wf": _core_identifier.ResourceType.WORKFLOW, "tsk": _core_identifier.ResourceType.TASK, } - _TYPE_TO_STRING_MAP = {v: k for k, v in _six.iteritems(_STRING_TO_TYPE_MAP)} + _TYPE_TO_STRING_MAP = {v: k for k, v in 
_STRING_TO_TYPE_MAP.items()} @classmethod def promote_from_model(cls, base_model): @@ -28,6 +25,11 @@ def promote_from_model(cls, base_model): base_model.version, ) + @classmethod + def from_flyte_idl(cls, pb2_object): + base_model = super().from_flyte_idl(pb2_object) + return cls.promote_from_model(base_model) + @classmethod def from_python_std(cls, string): """ @@ -63,106 +65,116 @@ def __str__(self): ) -class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier, metaclass=_sdk_bases.ExtendedSdkType): +class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.core.identifier.WorkflowExecutionIdentifier base_model: - :rtype: WorkflowExecutionIdentifier + :param flytekit.models.core.identifier.TaskExecutionIdentifier base_model: + :rtype: TaskExecutionIdentifier """ return cls( - base_model.project, - base_model.domain, - base_model.name, + task_id=base_model.task_id, + node_execution_id=base_model.node_execution_id, + retry_attempt=base_model.retry_attempt, ) + @classmethod + def from_flyte_idl(cls, pb2_object): + base_model = super().from_flyte_idl(pb2_object) + return cls.promote_from_model(base_model) + @classmethod def from_python_std(cls, string): """ Parses a string in the correct format into an identifier :param Text string: - :rtype: WorkflowExecutionIdentifier + :rtype: TaskExecutionIdentifier """ segments = string.split(":") - if len(segments) != 4: + if len(segments) != 10: raise _user_exceptions.FlyteValueException( string, "The provided string was not in a parseable format. The string for an identifier must be in the format" - " ex:project:domain:name.", + " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.", ) - resource_type, project, domain, name = segments + resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments - if resource_type != "ex": + if resource_type != "te": raise _user_exceptions.FlyteValueException( resource_type, "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", ) return cls( - project, - domain, - name, + task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), + node_execution_id=_core_identifier.NodeExecutionIdentifier( + node_id=node_id, + execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), + ), + retry_attempt=int(retry), ) def __str__(self): - return "ex:{}:{}:{}".format(self.project, self.domain, self.name) + return "te:{ep}:{ed}:{en}:{node_id}:{tp}:{td}:{tn}:{tv}:{retry}".format( + ep=self.node_execution_id.execution_id.project, + ed=self.node_execution_id.execution_id.domain, + en=self.node_execution_id.execution_id.name, + node_id=self.node_execution_id.node_id, + tp=self.task_id.project, + td=self.task_id.domain, + tn=self.task_id.name, + tv=self.task_id.version, + retry=self.retry_attempt, + ) -class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier, metaclass=_sdk_bases.ExtendedSdkType): +class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.core.identifier.TaskExecutionIdentifier base_model: - :rtype: TaskExecutionIdentifier + :param flytekit.models.core.identifier.WorkflowExecutionIdentifier base_model: + :rtype: WorkflowExecutionIdentifier """ return cls( - task_id=base_model.task_id, - node_execution_id=base_model.node_execution_id, - retry_attempt=base_model.retry_attempt, + base_model.project, + base_model.domain, + base_model.name, ) + @classmethod + def from_flyte_idl(cls, pb2_object): + base_model = super().from_flyte_idl(pb2_object) + return cls.promote_from_model(base_model) + @classmethod def from_python_std(cls, string): """ Parses a string in the correct format into an identifier :param Text string: - :rtype: TaskExecutionIdentifier + :rtype: WorkflowExecutionIdentifier """ segments = string.split(":") - if len(segments) != 10: + if len(segments) != 4: raise _user_exceptions.FlyteValueException( string, "The provided string was not in a parseable format. The string for an identifier must be in the format" - " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.", + " ex:project:domain:name.", ) - resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments + resource_type, project, domain, name = segments - if resource_type != "te": + if resource_type != "ex": raise _user_exceptions.FlyteValueException( resource_type, "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", ) return cls( - task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), - node_execution_id=_core_identifier.NodeExecutionIdentifier( - node_id=node_id, - execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), - ), - retry_attempt=int(retry), + project, + domain, + name, ) def __str__(self): - return "te:{ep}:{ed}:{en}:{node_id}:{tp}:{td}:{tn}:{tv}:{retry}".format( - ep=self.node_execution_id.execution_id.project, - ed=self.node_execution_id.execution_id.domain, - en=self.node_execution_id.execution_id.name, - node_id=self.node_execution_id.node_id, - tp=self.task_id.project, - td=self.task_id.domain, - tn=self.task_id.name, - tv=self.task_id.version, - retry=self.retry_attempt, - ) + return "ex:{}:{}:{}".format(self.project, self.domain, self.name) diff --git a/flytekit/interfaces/data/__init__.py b/flytekit/interfaces/data/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/interfaces/data/common.py b/flytekit/interfaces/data/common.py deleted file mode 100644 index 1544b608b6..0000000000 --- a/flytekit/interfaces/data/common.py +++ /dev/null @@ -1,54 +0,0 @@ -class DataProxy(object): - def __init__(self, name: str): - self._name = name - - @property - def name(self) -> str: - return self._name - - def exists(self, path): - """ - :param path: - :rtype: bool: whether the file exists or not - """ - pass - - def download_directory(self, remote_path, local_path): - """ - :param Text remote_path: - :param Text local_path: - """ - pass - - def download(self, remote_path, local_path): - """ - :param Text remote_path: - :param Text local_path: - """ - pass - - def upload(self, file_path, to_path): - """ - :param Text file_path: - :param Text to_path: - """ - pass - - def upload_directory(self, local_path, remote_path): - """ - :param Text local_path: - :param Text remote_path: - """ - pass - - def get_random_path(self): - """ - :rtype: Text - """ - pass - - def get_random_directory(self): - """ - :rtype: Text - """ - pass diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py deleted file mode 100644 index 7cf5dcae58..0000000000 --- a/flytekit/interfaces/data/data_proxy.py +++ /dev/null @@ -1,172 +0,0 @@ -from flytekit.common import constants as _constants -from flytekit.common import utils as _common_utils -from flytekit.common.exceptions import user as _user_exception -from flytekit.configuration import platform as _platform_config -from flytekit.configuration import sdk as _sdk_config -from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy -from flytekit.interfaces.data.http import http_data_proxy as _http_data_proxy -from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy -from flytekit.interfaces.data.s3 import s3proxy as _s3proxy - - -class LocalWorkingDirectoryContext(object): - _CONTEXTS = [] - - def __init__(self, directory): - self._directory = directory - - def __enter__(self): - self._CONTEXTS.append(self._directory) - - def __exit__(self, exc_type, exc_val, exc_tb): - self._CONTEXTS.pop() - - @classmethod - def get(cls): - return cls._CONTEXTS[-1] if cls._CONTEXTS else None - - -class _OutputDataContext(object): - _CONTEXTS = [_local_file_proxy.LocalFileProxy(_sdk_config.LOCAL_SANDBOX.get())] - - def __init__(self, context): - self._context = context - - def __enter__(self): - self._CONTEXTS.append(self._context) - - def __exit__(self, exc_type, exc_val, exc_tb): - 
self._CONTEXTS.pop() - - @classmethod - def get_active_proxy(cls): - return cls._CONTEXTS[-1] - - @classmethod - def get_default_proxy(cls): - return cls._CONTEXTS[0] - - -class LocalDataContext(_OutputDataContext): - def __init__(self, sandbox): - """ - :param Text sandbox: - """ - super(LocalDataContext, self).__init__(_local_file_proxy.LocalFileProxy(sandbox)) - - -class RemoteDataContext(_OutputDataContext): - _CLOUD_PROVIDER_TO_PROXIES = { - _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy, - _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy, - } - - def __init__(self, cloud_provider=None, raw_output_data_prefix_override=None): - """ - :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum - """ - cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get() - proxy_class = type(self)._CLOUD_PROVIDER_TO_PROXIES.get(cloud_provider, None) - if proxy_class is None: - raise _user_exception.FlyteAssertion( - "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format( - cloud_provider, list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys()) - ) - ) - proxy = proxy_class(raw_output_data_prefix_override) - super(RemoteDataContext, self).__init__(proxy) - - -class Data(object): - # TODO: More proxies for more environments. - _DATA_PROXIES = { - "s3:/": _s3proxy.AwsS3Proxy(), - "gs:/": _gcs_proxy.GCSProxy(), - "http://": _http_data_proxy.HttpFileProxy(), - "https://": _http_data_proxy.HttpFileProxy(), - } - - @classmethod - def _load_data_proxy_by_path(cls, path): - """ - :param Text path: - :rtype: flytekit.interfaces.data.common.DataProxy - """ - for k, v in cls._DATA_PROXIES.items(): - if path.startswith(k): - return v - return _OutputDataContext.get_default_proxy() - - @classmethod - def data_exists(cls, path): - """ - :param Text path: - :rtype: bool: whether the file exists or not - """ - with _common_utils.PerformanceTimer("Check file exists {}".format(path)): - proxy = cls._load_data_proxy_by_path(path) - return proxy.exists(path) - - @classmethod - def get_data(cls, remote_path, local_path, is_multipart=False): - """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: - """ - try: - with _common_utils.PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)): - proxy = cls._load_data_proxy_by_path(remote_path) - if is_multipart: - proxy.download_directory(remote_path, local_path) - else: - proxy.download(remote_path, local_path) - except Exception as ex: - raise _user_exception.FlyteAssertion( - "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" - "Original exception: {error_string}".format( - remote_path=remote_path, - local_path=local_path, - is_multipart=is_multipart, - error_string=str(ex), - ) - ) - - @classmethod - def put_data(cls, local_path, remote_path, is_multipart=False): - """ - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: - """ - try: - with _common_utils.PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)): - proxy = cls._load_data_proxy_by_path(remote_path) - if is_multipart: - proxy.upload_directory(local_path, remote_path) - else: - proxy.upload(local_path, remote_path) - except Exception as ex: - raise _user_exception.FlyteAssertion( - "Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" - "Original exception: {error_string}".format( - remote_path=remote_path, - local_path=local_path, - is_multipart=is_multipart, - 
error_string=str(ex), - ) - ) - - @classmethod - def get_remote_path(cls): - """ - :rtype: Text - """ - return _OutputDataContext.get_active_proxy().get_random_path() - - @classmethod - def get_remote_directory(cls): - """ - :rtype: Text - """ - return _OutputDataContext.get_active_proxy().get_random_directory() diff --git a/flytekit/interfaces/data/gcs/__init__.py b/flytekit/interfaces/data/gcs/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py deleted file mode 100644 index e8440299cd..0000000000 --- a/flytekit/interfaces/data/gcs/gcs_proxy.py +++ /dev/null @@ -1,150 +0,0 @@ -import os as _os -import sys as _sys -import uuid as _uuid - -from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException -from flytekit.configuration import gcp as _gcp_config -from flytekit.interfaces import random as _flyte_random -from flytekit.interfaces.data import common as _common_data -from flytekit.tools import subprocess as _subprocess - -if _sys.version_info >= (3,): - from shutil import which as _which -else: - from distutils.spawn import find_executable as _which - - -def _update_cmd_config_and_execute(cmd): - env = _os.environ.copy() - return _subprocess.check_call(cmd, env=env) - - -def _amend_path(path): - return _os.path.join(path, "*") if not path.endswith("*") else path - - -class GCSProxy(_common_data.DataProxy): - _GS_UTIL_CLI = "gsutil" - - def __init__(self, raw_output_data_prefix_override: str = None): - """ - :param raw_output_data_prefix_override: Instead of relying on the AWS or GCS configuration (see - S3_SHARD_FORMATTER for AWS and GCS_PREFIX for GCP) setting when computing the shard - path (_get_shard_path), use this prefix instead as a base. This code assumes that the - path passed in is correct. That is, an S3 path won't be passed in when running on GCP. - """ - self._raw_output_data_prefix_override = raw_output_data_prefix_override - super(GCSProxy, self).__init__(name="gcs-gsutil") - - @property - def raw_output_data_prefix_override(self) -> str: - return self._raw_output_data_prefix_override - - @staticmethod - def _check_binary(): - """ - Make sure that the `gsutil` cli is present - """ - if not _which(GCSProxy._GS_UTIL_CLI): - raise _FlyteUserException("gsutil (gcloud cli) not found! Please install.") - - @staticmethod - def _maybe_with_gsutil_parallelism(*gsutil_args): - """ - Check if we should run `gsutil` with the `-m` flag that enables - parallelism via multiple threads/processes. Additional tweaking of - this behavior can be achieved via the .boto configuration file. See: - https://cloud.google.com/storage/docs/boto-gsutil - """ - cmd = [GCSProxy._GS_UTIL_CLI] - if _gcp_config.GSUTIL_PARALLELISM.get(): - cmd.append("-m") - cmd.extend(gsutil_args) - - return cmd - - def exists(self, remote_path): - """ - :param Text remote_path: remote gs:// path - :rtype bool: whether the gs file exists or not - """ - GCSProxy._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. 
Please use FQN (GS ARN) of the format gs://...") - - cmd = [GCSProxy._GS_UTIL_CLI, "-q", "stat", remote_path] - try: - _update_cmd_config_and_execute(cmd) - return True - except Exception: - return False - - def download_directory(self, remote_path, local_path): - """ - :param Text remote_path: remote gs:// path - :param Text local_path: directory to copy to - """ - GCSProxy._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(remote_path), local_path) - return _update_cmd_config_and_execute(cmd) - - def download(self, remote_path, local_path): - """ - :param Text remote_path: remote gs:// path - :param Text local_path: directory to copy to - """ - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - GCSProxy._check_binary() - - cmd = self._maybe_with_gsutil_parallelism("cp", remote_path, local_path) - return _update_cmd_config_and_execute(cmd) - - def upload(self, file_path, to_path): - """ - :param Text file_path: - :param Text to_path: - """ - GCSProxy._check_binary() - - cmd = self._maybe_with_gsutil_parallelism("cp", file_path, to_path) - return _update_cmd_config_and_execute(cmd) - - def upload_directory(self, local_path, remote_path): - """ - :param Text local_path: - :param Text remote_path: - """ - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - GCSProxy._check_binary() - - cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(local_path), - remote_path if remote_path.endswith("/") else remote_path + "/", - ) - return _update_cmd_config_and_execute(cmd) - - def get_random_path(self) -> str: - """ - If this object was created with a raw output data prefix, usually set by Propeller/Plugins at execution time - and piped all the way here, it will be used instead of referencing the GCS_PREFIX configuration. - """ - key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex - prefix = self.raw_output_data_prefix_override or _gcp_config.GCS_PREFIX.get() - return _os.path.join(prefix, key) - - def get_random_directory(self): - """ - :rtype: Text - """ - return self.get_random_path() + "/" diff --git a/flytekit/interfaces/data/http/__init__.py b/flytekit/interfaces/data/http/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/interfaces/data/http/http_data_proxy.py b/flytekit/interfaces/data/http/http_data_proxy.py deleted file mode 100644 index 33e1909906..0000000000 --- a/flytekit/interfaces/data/http/http_data_proxy.py +++ /dev/null @@ -1,80 +0,0 @@ -import requests as _requests - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.interfaces.data import common as _common_data - - -class HttpFileProxy(_common_data.DataProxy): - - _HTTP_OK = 200 - _HTTP_FORBIDDEN = 403 - _HTTP_NOT_FOUND = 404 - - def __init__(self): - super(HttpFileProxy, self).__init__(name="http") - - def exists(self, path): - """ - :param Text path: the path of the file - :rtype bool: whether the file exists or not - """ - rsp = _requests.head(path) - allowed_codes = { - type(self)._HTTP_OK, - type(self)._HTTP_NOT_FOUND, - type(self)._HTTP_FORBIDDEN, - } - if rsp.status_code not in allowed_codes: - raise _user_exceptions.FlyteValueException( - rsp.status_code, - "Data at {} could not be checked for existence. 
Expected one of: {}".format(path, allowed_codes), - ) - return rsp.status_code == type(self)._HTTP_OK - - def download_directory(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - raise _user_exceptions.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.") - - def download(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - - rsp = _requests.get(from_path) - if rsp.status_code != type(self)._HTTP_OK: - raise _user_exceptions.FlyteValueException( - rsp.status_code, - "Request for data @ {} failed. Expected status code {}".format(from_path, type(self)._HTTP_OK), - ) - with open(to_path, "wb") as writer: - writer.write(rsp.content) - - def upload(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") - - def upload_directory(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") - - def get_random_path(self): - """ - :rtype: Text - """ - raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") - - def get_random_directory(self): - """ - :rtype: Text - """ - raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") diff --git a/flytekit/interfaces/data/local/__init__.py b/flytekit/interfaces/data/local/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/interfaces/data/local/local_file_proxy.py b/flytekit/interfaces/data/local/local_file_proxy.py deleted file mode 100644 index 2c9ea33cf5..0000000000 --- a/flytekit/interfaces/data/local/local_file_proxy.py +++ /dev/null @@ -1,92 +0,0 @@ -import os as _os -import uuid as _uuid -from distutils import dir_util as _dir_util -from shutil import copyfile as _copyfile - -from flytekit.interfaces import random as _flyte_random -from flytekit.interfaces.data import common as _common_data - - -def _make_local_path(path): - if not _os.path.exists(path): - try: - _os.makedirs(path) - except OSError: # Guard against race condition - if not _os.path.isdir(path): - raise - - -def strip_file_header(path: str) -> str: - if path.startswith("file://"): - return path.replace("file://", "", 1) - return path - - -class LocalFileProxy(_common_data.DataProxy): - def __init__(self, sandbox): - """ - :param Text sandbox: - """ - super().__init__(name="local") - self._sandbox = sandbox - - @property - def sandbox(self) -> str: - return self._sandbox - - def exists(self, path): - """ - :param Text path: the path of the file - :rtype bool: whether the file exists or not - """ - return _os.path.exists(strip_file_header(path)) - - def download_directory(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - if from_path != to_path: - _dir_util.copy_tree(strip_file_header(from_path), strip_file_header(to_path)) - - def download(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - _copyfile(strip_file_header(from_path), strip_file_header(to_path)) - - def upload(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - # Emulate s3's flat storage by automatically creating directory path - _make_local_path(_os.path.dirname(strip_file_header(to_path))) - # Write the object to a local file in 
the sandbox - _copyfile(strip_file_header(from_path), strip_file_header(to_path)) - - def upload_directory(self, from_path, to_path): - """ - :param Text from_path: - :param Text to_path: - """ - self.download_directory(from_path, to_path) - - def get_random_path(self): - """ - :rtype: Text - """ - # Create a 128-bit random hash because the birthday attack principle shows that there is about a 50% chance of a - # collision between objects when 2^(n/2) objects are created (where n is the number of bits in the hash). - # Assuming Flyte eventually creates 1 trillion pieces of data (~2 ^ 40), the likelihood - # of a collision is 10^-15 with 128-bit...or basically 0. - return _os.path.join(self._sandbox, _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex) - - def get_random_directory(self): - """ - :rtype: Text - """ - random_dir = self.get_random_path() + "/" - _make_local_path(random_dir) - return random_dir diff --git a/flytekit/interfaces/data/s3/__init__.py b/flytekit/interfaces/data/s3/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/interfaces/data/s3/s3proxy.py b/flytekit/interfaces/data/s3/s3proxy.py deleted file mode 100644 index ab11e7b738..0000000000 --- a/flytekit/interfaces/data/s3/s3proxy.py +++ /dev/null @@ -1,224 +0,0 @@ -import logging -import os as _os -import re as _re -import string as _string -import sys as _sys -import time -import uuid as _uuid -from typing import Dict, List - -from six import moves as _six_moves -from six import text_type as _text_type - -from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException -from flytekit.configuration import aws as _aws_config -from flytekit.interfaces import random as _flyte_random -from flytekit.interfaces.data import common as _common_data -from flytekit.tools import subprocess as _subprocess - -if _sys.version_info >= (3,): - from shutil import which as _which -else: - from distutils.spawn import find_executable as _which - - -def _update_cmd_config_and_execute(cmd: List[str]): - env = _os.environ.copy() - - if _aws_config.ENABLE_DEBUG.get(): - cmd.insert(1, "--debug") - - if _aws_config.S3_ENDPOINT.get() is not None: - cmd.insert(1, _aws_config.S3_ENDPOINT.get()) - cmd.insert(1, _aws_config.S3_ENDPOINT_ARG_NAME) - - if _aws_config.S3_ACCESS_KEY_ID.get() is not None: - env[_aws_config.S3_ACCESS_KEY_ID_ENV_NAME] = _aws_config.S3_ACCESS_KEY_ID.get() - - if _aws_config.S3_SECRET_ACCESS_KEY.get() is not None: - env[_aws_config.S3_SECRET_ACCESS_KEY_ENV_NAME] = _aws_config.S3_SECRET_ACCESS_KEY.get() - - retry = 0 - while True: - try: - return _subprocess.check_call(cmd, env=env) - except Exception as e: - logging.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") - retry += 1 - if retry > _aws_config.RETRIES.get(): - raise - secs = _aws_config.BACKOFF_SECONDS.get() - logging.info(f"Sleeping before retrying again, after {secs} seconds") - time.sleep(secs) - logging.info("Retrying again") - - -def _extra_args(extra_args: Dict[str, str]) -> List[str]: - cmd = [] - if "ContentType" in extra_args: - cmd += ["--content-type", extra_args["ContentType"]] - if "ContentEncoding" in extra_args: - cmd += ["--content-encoding", extra_args["ContentEncoding"]] - if "ACL" in extra_args: - cmd += ["--acl", extra_args["ACL"]] - return cmd - - -class AwsS3Proxy(_common_data.DataProxy): - _AWS_CLI = "aws" - _SHARD_CHARACTERS = [_text_type(x) for x in _six_moves.range(10)] + list(_string.ascii_lowercase) - - def __init__(self, raw_output_data_prefix_override: str = 
None): - """ - :param raw_output_data_prefix_override: Instead of relying on the AWS or GCS configuration (see - S3_SHARD_FORMATTER for AWS and GCS_PREFIX for GCP) setting when computing the shard - path (_get_shard_path), use this prefix instead as a base. This code assumes that the - path passed in is correct. That is, an S3 path won't be passed in when running on GCP. - """ - super().__init__(name="awscli-s3") - self._raw_output_data_prefix_override = raw_output_data_prefix_override - - @property - def raw_output_data_prefix_override(self) -> str: - return self._raw_output_data_prefix_override - - @staticmethod - def _check_binary(): - """ - Make sure that the AWS cli is present - """ - if not _which(AwsS3Proxy._AWS_CLI): - raise _FlyteUserException("AWS CLI not found at Please install.") - - @staticmethod - def _split_s3_path_to_bucket_and_key(path): - """ - :param Text path: - :rtype: (Text, Text) - """ - path = path[len("s3://") :] - first_slash = path.index("/") - return path[:first_slash], path[first_slash + 1 :] - - def exists(self, remote_path): - """ - :param Text remote_path: remote s3:// path - :rtype bool: whether the s3 file exists or not - """ - AwsS3Proxy._check_binary() - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - bucket, file_path = self._split_s3_path_to_bucket_and_key(remote_path) - cmd = [ - AwsS3Proxy._AWS_CLI, - "s3api", - "head-object", - "--bucket", - bucket, - "--key", - file_path, - ] - try: - _update_cmd_config_and_execute(cmd) - return True - except Exception as ex: - # The s3api command returns an error if the object does not exist. The error message contains - # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" - # This is a best effort for returning if the object does not exist by searching - # for existence of (404) in the error message. This should not be needed when we get off the cli and use lib - if _re.search("(404)", _text_type(ex)): - return False - else: - raise ex - - def download_directory(self, remote_path, local_path): - """ - :param Text remote_path: remote s3:// path - :param Text local_path: directory to copy to - """ - AwsS3Proxy._check_binary() - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - cmd = [AwsS3Proxy._AWS_CLI, "s3", "cp", "--recursive", remote_path, local_path] - return _update_cmd_config_and_execute(cmd) - - def download(self, remote_path, local_path): - """ - :param Text remote_path: remote s3:// path - :param Text local_path: directory to copy to - """ - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. 
Please use FQN (S3 ARN) of the format s3://...") - - AwsS3Proxy._check_binary() - cmd = [AwsS3Proxy._AWS_CLI, "s3", "cp", remote_path, local_path] - return _update_cmd_config_and_execute(cmd) - - def upload(self, file_path, to_path): - """ - :param Text file_path: - :param Text to_path: - """ - AwsS3Proxy._check_binary() - - extra_args = { - "ACL": "bucket-owner-full-control", - } - - cmd = [AwsS3Proxy._AWS_CLI, "s3", "cp"] - cmd.extend(_extra_args(extra_args)) - cmd += [file_path, to_path] - - return _update_cmd_config_and_execute(cmd) - - def upload_directory(self, local_path, remote_path): - """ - :param Text local_path: - :param Text remote_path: - """ - extra_args = { - "ACL": "bucket-owner-full-control", - } - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - AwsS3Proxy._check_binary() - cmd = [AwsS3Proxy._AWS_CLI, "s3", "cp", "--recursive"] - cmd.extend(_extra_args(extra_args)) - cmd += [local_path, remote_path] - return _update_cmd_config_and_execute(cmd) - - def get_random_path(self): - """ - :rtype: Text - """ - # Create a 128-bit random hash because the birthday attack principle shows that there is about a 50% chance of a - # collision between objects when 2^(n/2) objects are created (where n is the number of bits in the hash). - # Assuming Flyte eventually creates 1 trillion pieces of data (~2 ^ 40), the likelihood - # of a collision is 10^-15 with 128-bit...or basically 0. - key = _uuid.UUID(int=_flyte_random.random.getrandbits(128)).hex - return _os.path.join(self._get_shard_path(), key) - - def get_random_directory(self): - """ - :rtype: Text - """ - return self.get_random_path() + "/" - - def _get_shard_path(self) -> str: - """ - If this object was created with a raw output data prefix, usually set by Propeller/Plugins at execution time - and piped all the way here, it will be used instead of referencing the S3 shard configuration. 
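The removed S3 and GCS proxies share one output-layout idea worth recording: a random output path is a short random shard prefix plus a 128-bit random key, which keeps collisions negligible by the birthday bound. A minimal sketch of that scheme, with a hypothetical ``shard_formatter`` standing in for the S3_SHARD_FORMATTER configuration:

.. code-block:: python

    import os
    import random
    import string
    import uuid

    _SHARD_CHARACTERS = string.digits + string.ascii_lowercase

    def random_output_path(shard_formatter: str = "s3://my-bucket/{}/", shard_length: int = 2) -> str:
        # A short random shard spreads objects across key prefixes.
        shard = "".join(random.choice(_SHARD_CHARACTERS) for _ in range(shard_length))
        # 128 random bits: roughly 2^64 objects are needed before a ~50%
        # collision chance, so at ~2^40 objects the probability is ~10^-15.
        key = uuid.UUID(int=random.getrandbits(128)).hex
        return os.path.join(shard_formatter.format(shard), key)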
- """ - if self.raw_output_data_prefix_override: - return self.raw_output_data_prefix_override - - shard = "" - for _ in _six_moves.range(_aws_config.S3_SHARD_STRING_LENGTH.get()): - shard += _flyte_random.random.choice(self._SHARD_CHARACTERS) - return _aws_config.S3_SHARD_FORMATTER.get().format(shard) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 3ef6291ac7..63253bf399 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -1,7 +1,6 @@ import abc as _abc import json as _json -import six as _six from flyteidl.admin import common_pb2 as _common_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -58,7 +57,7 @@ def short_string(self): """ :rtype: Text """ - return _six.text_type(self.to_flyte_idl()) + return str(self.to_flyte_idl()) def verbose_string(self): """ @@ -333,7 +332,7 @@ def to_flyte_idl(self): """ :rtype: dict[Text, Text] """ - return _common_pb2.Labels(values={k: v for k, v in _six.iteritems(self.values)}) + return _common_pb2.Labels(values={k: v for k, v in self.values.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -341,7 +340,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.common_pb2.Labels pb2_object: :rtype: Labels """ - return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) + return cls({k: v for k, v in pb2_object.values.items()}) class Annotations(FlyteIdlEntity): @@ -361,7 +360,7 @@ def to_flyte_idl(self): """ :rtype: _common_pb2.Annotations """ - return _common_pb2.Annotations(values={k: v for k, v in _six.iteritems(self.values)}) + return _common_pb2.Annotations(values={k: v for k, v in self.values.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -369,7 +368,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.common_pb2.Annotations pb2_object: :rtype: Annotations """ - return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) + return cls({k: v for k, v in pb2_object.values.items()}) class UrlBlob(FlyteIdlEntity): diff --git a/flytekit/models/core/compiler.py b/flytekit/models/core/compiler.py index 3246ee22b3..929f816227 100644 --- a/flytekit/models/core/compiler.py +++ b/flytekit/models/core/compiler.py @@ -1,4 +1,3 @@ -import six as _six from flyteidl.core import compiler_pb2 as _compiler_pb2 from flytekit.models import common as _common @@ -61,8 +60,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.compiler_pb2.ConnectionSet """ return _compiler_pb2.ConnectionSet( - upstream={k: v.to_flyte_idl() for k, v in _six.iteritems(self.upstream)}, - downstream={k: v.to_flyte_idl() for k, v in _six.iteritems(self.upstream)}, + upstream={k: v.to_flyte_idl() for k, v in self.upstream.items()}, + downstream={k: v.to_flyte_idl() for k, v in self.upstream.items()}, ) @classmethod @@ -72,8 +71,8 @@ def from_flyte_idl(cls, p): :rtype: ConnectionSet """ return cls( - upstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in _six.iteritems(p.upstream)}, - downstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in _six.iteritems(p.downstream)}, + upstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in p.upstream.items()}, + downstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in p.downstream.items()}, ) diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index 5364d39c3e..b0c9fc882a 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -1,6 +1,5 @@ import typing -import six as _six from flyteidl.core import interface_pb2 as 
_interface_pb2 from flytekit.models import common as _common @@ -73,7 +72,7 @@ def to_flyte_idl(self): """ :rtype: dict[Text, Variable] """ - return _interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.variables)}) + return _interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.variables.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -81,7 +80,7 @@ def from_flyte_idl(cls, pb2_object): :param dict[Text, Variable] pb2_object: :rtype: VariableMap """ - return cls({k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(pb2_object.variables)}) + return cls({k: Variable.from_flyte_idl(v) for k, v in pb2_object.variables.items()}) class TypedInterface(_common.FlyteIdlEntity): @@ -106,10 +105,8 @@ def outputs(self) -> typing.Dict[str, Variable]: def to_flyte_idl(self) -> _interface_pb2.TypedInterface: return _interface_pb2.TypedInterface( - inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.inputs)}), - outputs=_interface_pb2.VariableMap( - variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.outputs)} - ), + inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.inputs.items()}), + outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.outputs.items()}), ) @classmethod @@ -118,8 +115,8 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface :param proto: """ return cls( - inputs={k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(proto.inputs.variables)}, - outputs={k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(proto.outputs.variables)}, + inputs={k: Variable.from_flyte_idl(v) for k, v in proto.inputs.variables.items()}, + outputs={k: Variable.from_flyte_idl(v) for k, v in proto.outputs.variables.items()}, ) @@ -211,7 +208,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.interface_pb2.ParameterMap """ return _interface_pb2.ParameterMap( - parameters={k: v.to_flyte_idl() for k, v in _six.iteritems(self.parameters)}, + parameters={k: v.to_flyte_idl() for k, v in self.parameters.items()}, ) @classmethod @@ -220,4 +217,4 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.interface_pb2.ParameterMap pb2_object: :rtype: ParameterMap """ - return cls(parameters={k: Parameter.from_flyte_idl(v) for k, v in _six.iteritems(pb2_object.parameters)}) + return cls(parameters={k: Parameter.from_flyte_idl(v) for k, v in pb2_object.parameters.items()}) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index f6947ac75c..bc398ab7a7 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -5,7 +5,7 @@ from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf.struct_pb2 import Struct -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types from flytekit.models.types import OutputReference as _OutputReference diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 76caa0444d..caf2471e58 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -1,23 +1,18 @@ import json as _json import typing -import six as _six from flyteidl.admin import task_pb2 as _admin_task from flyteidl.core import compiler_pb2 as _compiler from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core import tasks_pb2 as _core_task -from 
flyteidl.plugins import spark_pb2 as _spark_task from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct -from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models import interface as _interface from flytekit.models import literals as _literals from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier -from flytekit.plugins import flyteidl as _lazy_flyteidl -from flytekit.sdk.spark_types import SparkType as _spark_type class Resources(_common.FlyteIdlEntity): @@ -617,151 +612,6 @@ def from_flyte_idl(cls, pb2_object): return cls(template=TaskTemplate.from_flyte_idl(pb2_object.template)) -class SparkJob(_common.FlyteIdlEntity): - """ - This model is deprecated and will be removed in 1.0.0. Please use the definition in the - flytekit spark plugin instead. - """ - - def __init__( - self, - spark_type, - application_file, - main_class, - spark_conf, - hadoop_conf, - executor_path, - ): - """ - This defines a SparkJob target. It will execute the appropriate SparkJob. - - :param application_file: The main application file to execute. - :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. - :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. - """ - self._application_file = application_file - self._spark_type = spark_type - self._main_class = main_class - self._executor_path = executor_path - self._spark_conf = spark_conf - self._hadoop_conf = hadoop_conf - - def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None - ) -> "SparkJob": - if not new_spark_conf: - new_spark_conf = self.spark_conf - - if not new_hadoop_conf: - new_hadoop_conf = self.hadoop_conf - - return SparkJob( - spark_type=self.spark_type, - application_file=self.application_file, - main_class=self.main_class, - spark_conf=new_spark_conf, - hadoop_conf=new_hadoop_conf, - executor_path=self.executor_path, - ) - - @property - def main_class(self): - """ - The main class to execute - :rtype: Text - """ - return self._main_class - - @property - def spark_type(self): - """ - Spark Job Type - :rtype: Text - """ - return self._spark_type - - @property - def application_file(self): - """ - The main application file to execute - :rtype: Text - """ - return self._application_file - - @property - def executor_path(self): - """ - The python executable to use - :rtype: Text - """ - return self._executor_path - - @property - def spark_conf(self): - """ - A definition of key-value pairs for spark config for the job. - :rtype: dict[Text, Text] - """ - return self._spark_conf - - @property - def hadoop_conf(self): - """ - A definition of key-value pairs for hadoop config for the job. 
- :rtype: dict[Text, Text] - """ - return self._hadoop_conf - - def to_flyte_idl(self): - """ - :rtype: flyteidl.plugins.spark_pb2.SparkJob - """ - - if self.spark_type == _spark_type.PYTHON: - application_type = _spark_task.SparkApplication.PYTHON - elif self.spark_type == _spark_type.JAVA: - application_type = _spark_task.SparkApplication.JAVA - elif self.spark_type == _spark_type.SCALA: - application_type = _spark_task.SparkApplication.SCALA - elif self.spark_type == _spark_type.R: - application_type = _spark_task.SparkApplication.R - else: - raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") - - return _spark_task.SparkJob( - applicationType=application_type, - mainApplicationFile=self.application_file, - mainClass=self.main_class, - executorPath=self.executor_path, - sparkConf=self.spark_conf, - hadoopConf=self.hadoop_conf, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.plugins.spark_pb2.SparkJob pb2_object: - :rtype: SparkJob - """ - - application_type = _spark_type.PYTHON - if pb2_object.type == _spark_task.SparkApplication.JAVA: - application_type = _spark_type.JAVA - elif pb2_object.type == _spark_task.SparkApplication.SCALA: - application_type = _spark_type.SCALA - elif pb2_object.type == _spark_task.SparkApplication.R: - application_type = _spark_type.R - - return cls( - type=application_type, - spark_conf=pb2_object.sparkConf, - application_file=pb2_object.mainApplicationFile, - main_class=pb2_object.mainClass, - hadoop_conf=pb2_object.hadoopConf, - executor_path=pb2_object.executorPath, - ) - - class IOStrategy(_common.FlyteIdlEntity): """ Provides methods to manage data in and out of the Raw container using Download Modes. This can only be used if DataLoadingConfig is enabled. @@ -930,8 +780,8 @@ def to_flyte_idl(self): command=self.command, args=self.args, resources=self.resources.to_flyte_idl(), - env=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in _six.iteritems(self.env)], - config=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in _six.iteritems(self.config)], + env=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.env.items()], + config=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.config.items()], data_config=self._data_loading_config.to_flyte_idl() if self._data_loading_config else None, ) @@ -1045,72 +895,3 @@ def from_flyte_idl(cls, pb2_object: _core_task.Sql): statement=pb2_object.statement, dialect=pb2_object.dialect, ) - - -class SidecarJob(_common.FlyteIdlEntity): - def __init__(self, pod_spec, primary_container_name, annotations=None, labels=None): - """ - A sidecar job represents the full kubernetes pod spec and related metadata required for executing a sidecar - task. 
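The SparkJob conversion above maps the SparkType enum to the protobuf application type with a chain of elifs in each direction. The same round trip can be kept in a single lookup table; the sketch below uses stand-in enum values, since the real ones come from flyteidl.plugins.spark_pb2:

.. code-block:: python

    import enum

    class SparkType(enum.Enum):
        PYTHON = 1
        SCALA = 2
        JAVA = 3
        R = 4

    class SparkApplication(enum.IntEnum):
        # Hypothetical values standing in for spark_pb2.SparkApplication.
        PYTHON = 0
        JAVA = 1
        SCALA = 2
        R = 3

    _TO_IDL = {
        SparkType.PYTHON: SparkApplication.PYTHON,
        SparkType.JAVA: SparkApplication.JAVA,
        SparkType.SCALA: SparkApplication.SCALA,
        SparkType.R: SparkApplication.R,
    }
    _FROM_IDL = {v: k for k, v in _TO_IDL.items()}

    def application_type(spark_type: SparkType) -> SparkApplication:
        try:
            return _TO_IDL[spark_type]
        except KeyError:
            raise ValueError(f"Invalid Spark Application Type Specified: {spark_type}")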
- - :param pod_spec: k8s.io.api.core.v1.PodSpec - :param primary_container_name: Text - :param dict[Text, Text] annotations: - :param dict[Text, Text] labels: - """ - self._pod_spec = pod_spec - self._primary_container_name = primary_container_name - self._annotations = annotations - self._labels = labels - - @property - def pod_spec(self): - """ - :rtype: k8s.io.api.core.v1.PodSpec - """ - return self._pod_spec - - @property - def primary_container_name(self): - """ - :rtype: Text - """ - return self._primary_container_name - - @property - def annotations(self): - """ - :rtype: dict[Text,Text] - """ - return self._annotations - - @property - def labels(self): - """ - :rtype: dict[Text,Text] - """ - return self._labels - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.tasks_pb2.SidecarJob - """ - return _lazy_flyteidl.plugins.sidecar_pb2.SidecarJob( - pod_spec=self.pod_spec, - primary_container_name=self.primary_container_name, - annotations=self.annotations, - labels=self.labels, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.task_pb2.Task pb2_object: - :rtype: Container - """ - return cls( - pod_spec=pb2_object.pod_spec, - primary_container_name=pb2_object.primary_container_name, - annotations=pb2_object.annotations, - labels=pb2_object.labels, - ) diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py deleted file mode 100644 index 61333933f1..0000000000 --- a/flytekit/plugins/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -This file is for old style plugins - for new plugins that work with the Python native-typed Flytekit, please -refer to the plugin specific directory underneath the plugins folder at the top level of this repository. -""" -from flytekit.tools import lazy_loader as _lazy_loader - -pyspark = _lazy_loader.lazy_load_module("pyspark") # type: _lazy_loader._LazyLoadModule - -k8s = _lazy_loader.lazy_load_module("k8s") # type: _lazy_loader._LazyLoadModule -type(k8s).add_sub_module("io.api.core.v1.generated_pb2") -type(k8s).add_sub_module("io.apimachinery.pkg.api.resource.generated_pb2") - -flyteidl = _lazy_loader.lazy_load_module("flyteidl") # type: _lazy_loader._LazyLoadModule -type(flyteidl).add_sub_module("plugins.sidecar_pb2") - -numpy = _lazy_loader.lazy_load_module("numpy") # type: _lazy_loader._LazyLoadModule -pandas = _lazy_loader.lazy_load_module("pandas") # type: _lazy_loader._LazyLoadModule - -hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: _lazy_loader._LazyLoadModule -type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes") - -_lazy_loader.LazyLoadPlugin("spark", ["pyspark>=2.4.0,<3.0.0"], [pyspark]) - -_lazy_loader.LazyLoadPlugin("spark3", ["pyspark>=3.0.0"], [pyspark]) - -_lazy_loader.LazyLoadPlugin("sidecar", ["k8s-proto>=0.0.3,<1.0.0"], [k8s, flyteidl]) - -_lazy_loader.LazyLoadPlugin( - "schema", - ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], - [numpy, pandas], -) - -_lazy_loader.LazyLoadPlugin("hive_sensor", ["hmsclient>=0.0.1,<1.0.0"], [hmsclient]) diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py index 367cab8997..877a6d6494 100644 --- a/flytekit/remote/component_nodes.py +++ b/flytekit/remote/component_nodes.py @@ -1,7 +1,7 @@ import logging as _logging from typing import Dict -from flytekit.common.exceptions import system as _system_exceptions +from flytekit.exceptions import system as _system_exceptions from flytekit.models import launch_plan as _launch_plan_model from flytekit.models import 
task as _task_model from flytekit.models.core import identifier as id_models diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index b581769f52..64b3fcb7d8 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional, Union -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.exceptions import user as user_exceptions from flytekit.core.type_engine import LiteralsResolver +from flytekit.exceptions import user as _user_exceptions +from flytekit.exceptions import user as user_exceptions from flytekit.models import execution as execution_models from flytekit.models import node_execution as node_execution_models from flytekit.models.admin import task_execution as admin_task_execution_models diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py index 016e3a3489..b5dddea03f 100644 --- a/flytekit/remote/launch_plan.py +++ b/flytekit/remote/launch_plan.py @@ -1,10 +1,7 @@ from typing import Optional -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine -from flytekit.engines.flyte import engine as _flyte_engine from flytekit.models import interface as _interface_models from flytekit.models import launch_plan as _launch_plan_models from flytekit.models.core import identifier as id_models @@ -94,14 +91,5 @@ def guessed_python_interface(self, value): return self._python_interface = value - @_exception_scopes.system_entry_point - def update(self, state: _launch_plan_models.LaunchPlanState): - if not self.id: - raise _user_exceptions.FlyteAssertion( - "Failed to update launch plan because the launch plan's ID is not set. 
Please call register to fetch " - "or register the identifier first" - ) - return _flyte_engine.get_client().update_launch_plan(self.id, state) - def __repr__(self) -> str: return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})" diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py index f8ae1b2d6a..4131905c5b 100644 --- a/flytekit/remote/nodes.py +++ b/flytekit/remote/nodes.py @@ -3,11 +3,11 @@ import logging as _logging from typing import Dict, List, Optional, Union -from flytekit.common import constants as _constants -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import hash as _hash_mixin +from flytekit.core import constants as _constants +from flytekit.core import hash as _hash_mixin from flytekit.core.promise import NodeOutput +from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.models import launch_plan as _launch_plan_model from flytekit.models import task as _task_model from flytekit.models.core import identifier as id_models diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 2178f148d3..44d5e8617c 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -19,14 +19,14 @@ from flyteidl.core import literals_pb2 as literals_pb2 from flytekit.clients.friendly import SynchronousFlyteClient -from flytekit.common import utils as common_utils -from flytekit.common.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException from flytekit.configuration import internal from flytekit.configuration import platform as platform_config from flytekit.configuration import sdk as sdk_config from flytekit.configuration import set_flyte_config_file -from flytekit.core import context_manager +from flytekit.core import constants, context_manager, utils from flytekit.core.interface import Interface +from flytekit.exceptions import user as user_exceptions +from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException from flytekit.loggers import remote_logger from flytekit.models import filters as filter_models from flytekit.models.admin import common as admin_common_models @@ -39,9 +39,6 @@ from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.clis.flyte_cli.main import _detect_default_config_file from flytekit.clis.sdk_in_container import serialize -from flytekit.common import constants -from flytekit.common.exceptions import user as user_exceptions -from flytekit.common.translator import FlyteControlPlaneEntity, FlyteLocalEntity, get_serializable from flytekit.configuration import auth as auth_config from flytekit.configuration.internal import DOMAIN, PROJECT from flytekit.core.base_task import PythonTask @@ -68,6 +65,7 @@ from flytekit.remote.nodes import FlyteNode from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow +from flytekit.tools.translator import FlyteControlPlaneEntity, FlyteLocalEntity, get_serializable ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] @@ -1297,7 +1295,7 @@ def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> liter tmp_name = os.path.join(ctx.file_access.local_sandbox_dir, "inputs.pb") ctx.file_access.get_data(execution_data.inputs.url, tmp_name) return 
literal_models.LiteralMap.from_flyte_idl( - common_utils.load_proto_from_file(literals_pb2.LiteralMap, tmp_name) + utils.load_proto_from_file(literals_pb2.LiteralMap, tmp_name) ) return literal_models.LiteralMap({}) @@ -1310,6 +1308,6 @@ def _get_output_literal_map(self, execution_data: ExecutionDataResponse) -> lite tmp_name = os.path.join(ctx.file_access.local_sandbox_dir, "outputs.pb") ctx.file_access.get_data(execution_data.outputs.url, tmp_name) return literal_models.LiteralMap.from_flyte_idl( - common_utils.load_proto_from_file(literals_pb2.LiteralMap, tmp_name) + utils.load_proto_from_file(literals_pb2.LiteralMap, tmp_name) ) return literal_models.LiteralMap({}) diff --git a/flytekit/remote/task.py b/flytekit/remote/task.py index 0c48f5f15e..1ff99549d6 100644 --- a/flytekit/remote/task.py +++ b/flytekit/remote/task.py @@ -1,6 +1,6 @@ from typing import Optional -from flytekit.common.mixins import hash as _hash_mixin +from flytekit.core import hash as _hash_mixin from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py index 396f377500..010a703187 100644 --- a/flytekit/remote/workflow.py +++ b/flytekit/remote/workflow.py @@ -2,11 +2,11 @@ from typing import Dict, List, Optional -from flytekit.common import constants as _constants -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.mixins import hash as _hash_mixin +from flytekit.core import constants as _constants +from flytekit.core import hash as _hash_mixin from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import user as _user_exceptions from flytekit.models import launch_plan as launch_plan_models from flytekit.models import task as _task_models from flytekit.models.core import compiler as compiler_models diff --git a/flytekit/sdk/__init__.py b/flytekit/sdk/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/sdk/exceptions.py b/flytekit/sdk/exceptions.py deleted file mode 100644 index 89874a5518..0000000000 --- a/flytekit/sdk/exceptions.py +++ /dev/null @@ -1,11 +0,0 @@ -from flytekit.common.exceptions import user as _user - - -class RecoverableException(_user.FlyteRecoverableException): - """ - Raise an exception of this type if user code detects an error and would like to force a retry of the entire task. - Any exception raised from user code other than RecoverableException will NOT be considered retryable and the task - will fail without additional retries. 
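The retry contract described here is easy to misread, so a usage sketch against the removed flytekit.sdk.exceptions module may help: only this exception type marks a failure as retryable, and each raise consumes one of the task's configured retries.

.. code-block:: python

    from flytekit.sdk.exceptions import RecoverableException  # module removed by this diff

    def poll_backend(healthy: bool) -> str:
        if not healthy:
            # Transient condition: ask the platform to retry the entire task.
            raise RecoverableException("backend temporarily unavailable, retrying")
        return "done"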
- """ - - pass diff --git a/flytekit/sdk/spark_types.py b/flytekit/sdk/spark_types.py deleted file mode 100644 index 9477895fac..0000000000 --- a/flytekit/sdk/spark_types.py +++ /dev/null @@ -1,8 +0,0 @@ -import enum - - -class SparkType(enum.Enum): - PYTHON = 1 - SCALA = 2 - JAVA = 3 - R = 4 diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py deleted file mode 100644 index 0c73ad66b7..0000000000 --- a/flytekit/sdk/tasks.py +++ /dev/null @@ -1,1244 +0,0 @@ -import datetime as _datetime - -import six as _six - -from flytekit.common import constants as _common_constants -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import generic_spark_task as _sdk_generic_spark_task -from flytekit.common.tasks import hive_task as _sdk_hive_tasks -from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic -from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks -from flytekit.common.tasks import sidecar_task as _sdk_sidecar_tasks -from flytekit.common.tasks import spark_task as _sdk_spark_tasks -from flytekit.common.tasks import task as _task -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_model -from flytekit.sdk.spark_types import SparkType as _spark_type - - -def inputs(_task_template=None, **kwargs): - """ - Decorator that provides input definitions to a decorated task definition. - - .. note:: - - Certain tasks have special input behavior. See comments on each task decorator for more information. - - .. code-block:: python - - @inputs(in1=Types.Integer, in2=[Types.String], in3=[[[Types.Integer]]]) - @outputs(out1=Types.Integer, out2=Types.String) - @python_task - def my_task(wf_params, in1, in2, out1, out2): - pass - - :param flytekit.common.tasks.sdk_runnable.SdkRunnableTask _task_template: Do not declare directly. This is the - decorated task template. - :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] kwargs: Arbitrary keyword arguments for input - name and type. - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - - def apply_inputs_wrapper(task): - if not isinstance(task, _task.SdkTask): - additional_msg = ( - "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, - task.__name__ if hasattr(task, "__name__") else "", - ) - ) - raise _user_exceptions.FlyteTypeException( - expected_type=_sdk_runnable_tasks.SdkRunnableTask, - received_type=type(task), - received_value=task, - additional_msg=additional_msg, - ) - for k, v in _six.iteritems(kwargs): - kwargs[k] = _interface_model.Variable( - _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), "" - ) # TODO: Support descriptions - - task.add_inputs(kwargs) - return task - - if _task_template is not None: - return apply_inputs_wrapper(_task_template) - else: - return apply_inputs_wrapper - - -def outputs(_task_template=None, **kwargs): - """ - Decorator that provides output definitions to a decorated task definition. - - .. note:: - - Certain tasks have special output behavior. See comments on each task decorator for more information. - - .. code-block:: python - - @outputs(out1=Types.Integer, out2=Types.String) - @python_task - def my_task(wf_params, out1, out2): - out1.set(123) - out2.set('hello world!') - - :param flytekit.common.tasks.sdk_runnable.SdkRunnableTask _task_template: Do not declare directly. This is the - decorated task template. 
- :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] kwargs: Arbitrary keyword arguments for input - name and type. - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - - def apply_outputs_wrapper(task): - if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask) and not isinstance( - task, _nb_tasks.SdkNotebookTask - ): - additional_msg = ( - "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, - task.__name__ if hasattr(task, "__name__") else "", - ) - ) - raise _user_exceptions.FlyteTypeException( - expected_type=_sdk_runnable_tasks.SdkRunnableTask, - received_type=type(task), - received_value=task, - additional_msg=additional_msg, - ) - for k, v in _six.iteritems(kwargs): - kwargs[k] = _interface_model.Variable( - _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), "" - ) # TODO: Support descriptions - - task.add_outputs(kwargs) - return task - - if _task_template is not None: - return apply_outputs_wrapper(_task_template) - else: - return apply_outputs_wrapper - - -def python_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=None, - deprecated="", - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - environment=None, - cache_serialize=False, - cls=None, -): - """ - Decorator to create a Python Task definition. This task will run as a single unit of work on the platform. - - .. code-block:: python - - @inputs(int_list=[Types.Integer]) - @outputs(sum_of_list=Types.Integer - @python_task - def my_task(wf_params, int_list, sum_of_list): - sum_of_list.set(sum(int_list)) - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed for wrapped task functions. - - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. - - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param bool interruptible: [optional] boolean describing if the task is interruptible. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - - :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for the task to run. Default is set by platform-level configuration. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. 
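Putting the decorator stack together, a task declaration under this (now removed) API looked roughly like the sketch below; resource strings follow the Kubernetes quantity format, e.g. "500m" of CPU and "1Gi" of memory:

.. code-block:: python

    from flytekit.sdk.tasks import inputs, outputs, python_task  # removed by this diff
    from flytekit.sdk.types import Types

    @inputs(int_list=[Types.Integer])
    @outputs(sum_of_list=Types.Integer)
    @python_task(cache=True, cache_version="1", retries=2,
                 cpu_request="500m", memory_request="1Gi")
    def sum_task(wf_params, int_list, sum_of_list):
        sum_of_list.set(sum(int_list))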
- - TODO: Add links to resource string documentation for Kubernetes - - :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs. - Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for the task to execute. Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - - :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not - guaranteed! If not specified, it is set equal to gpu_request. - - :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to - memory_request. - - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - - :param bool cache_serialize: [optional] boolean describing if instances of this cachable task should be executed - in serial. This means only a single instances executes and other concurrent executions wait for it to complete - and reuse the cached outputs. - - :param cls: This can be used to override the task implementation with a user-defined extension. The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to - inject bespoke logic into the base Flyte programming model. 
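The limit parameters above all follow one defaulting rule: an unspecified limit falls back to the corresponding request. A small sketch of that resolution logic, with illustrative names:

.. code-block:: python

    def resolve_limits(requests: dict, limits: dict) -> dict:
        # Any limit not set explicitly defaults to its request, mirroring
        # "If not specified, it is set equal to <x>_request" above.
        resolved = dict(limits)
        for resource, request in requests.items():
            resolved.setdefault(resource, request)
        return resolved

    resolve_limits({"cpu": "500m", "memory": "1Gi"}, {"memory": "2Gi"})
    # -> {"memory": "2Gi", "cpu": "500m"}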
- - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - - def wrapper(fn): - return (cls or _sdk_runnable_tasks.SdkRunnableTask)( - task_function=fn, - task_type=_common_constants.SdkTaskType.PYTHON_TASK, - discovery_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment, - cache_serializable=cache_serialize, - custom={}, - ) - - if _task_function: - return wrapper(_task_function) - else: - return wrapper - - -def dynamic_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=None, - deprecated="", - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - allowed_failure_ratio=None, - max_concurrency=None, - environment=None, - cache_serialize=False, - cls=None, -): - """ - Decorator to create a custom dynamic task definition. Dynamic tasks should be used to split up work into - an arbitrary number of parallel sub-tasks, or workflows. - - .. code-block:: python - - @outputs(out=Types.Integer) - @python_task - def my_sub_task(wf_params, out): - out.set(randint()) - - @outputs(out=[Types.Integer]) - @dynamic_task - def my_task(wf_params, out): - out_list = [] - for i in xrange(100): - out_list.append(my_sub_task().outputs.out) - out.set(out_list) - - .. note:: - - All outputs of a batch task must be a list. This is because the individual outputs of sub-tasks should be - appended into a list. There cannot be aggregation of outputs done in this task. To accomplish aggregation, - it is recommended that a python_task take the outputs of this task as input and do the necessary work. - If a sub-task does not contribute an output, it must be yielded from the task with the `yield` keyword or - returned from the task in a list. If this isn't done, the sub-task will not be executed. - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed. - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. - - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param bool interruptible: [optional] boolean describing if the task is interruptible. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. 
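Both python_task above and dynamic_task below rely on the decorator-with-optional-arguments idiom: the decorator works bare (@python_task) and parameterized (@python_task(retries=2)). A standalone sketch of the idiom:

.. code-block:: python

    import functools

    def my_task(_task_function=None, retries=0):
        def wrapper(fn):
            @functools.wraps(fn)
            def inner(*args, **kwargs):
                return fn(*args, **kwargs)
            inner.retries = retries  # stand-in for real task construction
            return inner
        if _task_function is not None:
            # Bare form: @my_task passed the function itself as _task_function.
            return wrapper(_task_function)
        # Parameterized form: @my_task(retries=2) returns the real decorator.
        return wrapper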
- :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for the task to run. Default is set by platform-level configuration. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs. - Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for the task to execute. Default is set by platform-level configuration. - TODO: Add links to resource string documentation for Kubernetes - :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not - guaranteed! If not specified, it is set equal to gpu_request. - :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to - memory_request. - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - :param float allowed_failure_ratio: [optional] float value describing the ratio of sub-tasks that may fail before - the master batch task considers itself a failure. By default, the value is 0 so if any task fails, the master - batch task will be marked a failure. If specified, the value must be between 0 and 1 inclusive. In the event a - non-zero value is specified, downstream tasks must be able to accept None values as outputs from individual - sub-tasks because the output values will be set to None for any sub-task that fails. - :param int max_concurrency: [optional] integer value describing the maximum number of tasks to run concurrently. - This is a stand-in pending better concurrency controls for special use-cases. The existence of this parameter - is not guaranteed between versions and therefore it is NOT recommended that it be used. - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - :param bool cache_serialize: [optional] boolean describing if instances of this cachable task should be executed - in serial. 
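The allowed_failure_ratio contract above is worth a concrete check: failed sub-tasks surface as None outputs, and the batch fails once the failure ratio exceeds the threshold. A sketch, with illustrative names:

.. code-block:: python

    from typing import List, Optional

    def batch_succeeded(outputs: List[Optional[int]], allowed_failure_ratio: float = 0.0) -> bool:
        # Sub-tasks that failed contribute None; downstream consumers must
        # therefore tolerate None values whenever the ratio is non-zero.
        failures = sum(1 for out in outputs if out is None)
        return failures / len(outputs) <= allowed_failure_ratio

    batch_succeeded([1, None, 3, 4], allowed_failure_ratio=0.25)  # True
    batch_succeeded([1, None, 3, 4])                              # False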
This means only a single instances executes and other concurrent executions wait for it to complete - and reuse the cached outputs. - :param cls: This can be used to override the task implementation with a user-defined extension. The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. Generally, it should be a - subclass of flytekit.common.tasks.sdk_dynamic.SdkDynamicTask. A user can use this parameter to inject bespoke - logic into the base Flyte programming model. - :rtype: flytekit.common.tasks.sdk_runnable.SdkDynamicTask - """ - - def wrapper(fn): - return (cls or _sdk_dynamic.SdkDynamicTask)( - task_function=fn, - task_type=_common_constants.SdkTaskType.DYNAMIC_TASK, - discovery_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - allowed_failure_ratio=allowed_failure_ratio, - max_concurrency=max_concurrency, - environment=environment or {}, - cache_serializable=cache_serialize, - custom={}, - ) - - if _task_function: - return wrapper(_task_function) - else: - return wrapper - - -def spark_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=None, - deprecated="", - cache=False, - timeout=None, - spark_conf=None, - hadoop_conf=None, - environment=None, - cache_serialize=False, - cls=None, -): - """ - Decorator to create a spark task. This task will connect to a Spark cluster, configure the environment, - and then execute the code within the _task_function as the Spark driver program. - - .. code-block:: python - - @inputs(a=Types.Integer) - @spark_task( - spark_conf={ - 'spark.executor.cores': '7', - 'spark.executor.instances': '31', - 'spark.executor.memory': '32G' - } - ) - def sparky(wf_params, spark_context, a): - pass - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed for wrapped task functions. - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. - - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. 
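Note how the task wrappers encode the timeout rule stated here: a null timedelta means "no timeout", and each wrapper normalizes a missing value with ``timeout or _datetime.timedelta(seconds=0)``. A quick illustration of why that works:

.. code-block:: python

    import datetime

    def normalize_timeout(timeout=None) -> datetime.timedelta:
        # None and timedelta(0) are both falsy, so either collapses to the
        # "run indefinitely / no timeout" sentinel used by the task models.
        return timeout or datetime.timedelta(seconds=0)

    assert normalize_timeout() == datetime.timedelta(0)
    assert normalize_timeout(datetime.timedelta(0)) == datetime.timedelta(0)
    assert normalize_timeout(datetime.timedelta(minutes=5)) == datetime.timedelta(minutes=5)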
-def generic_spark_task(
-    spark_type,
-    main_class,
-    main_application_file,
-    cache_version="",
-    retries=0,
-    interruptible=None,
-    inputs=None,
-    deprecated="",
-    cache=False,
-    timeout=None,
-    spark_conf=None,
-    hadoop_conf=None,
-    environment=None,
-    cache_serialize=False,
-):
-    """
-    Create a generic spark task. This task will connect to a Spark cluster, configure the environment,
-    and then execute the main_class code as the Spark driver program.
-    """
-
-    return _sdk_generic_spark_task.SdkGenericSparkTask(
-        task_type=_common_constants.SdkTaskType.SPARK_TASK,
-        discovery_version=cache_version,
-        retries=retries,
-        interruptible=interruptible,
-        deprecated=deprecated,
-        discoverable=cache,
-        timeout=timeout or _datetime.timedelta(seconds=0),
-        spark_type=spark_type,
-        task_inputs=inputs,
-        main_class=main_class or "",
-        main_application_file=main_application_file or "",
-        spark_conf=spark_conf or {},
-        hadoop_conf=hadoop_conf or {},
-        environment=environment or {},
-        cache_serializable=cache_serialize,
-    )
-
-
-def qubole_spark_task(*args, **kwargs):
-    """
-    :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
-    """
-    raise NotImplementedError("Qubole Spark Task is currently not supported in Flyte.")
-
-
-def hive_task(
-    _task_function=None,
-    cache_version="",
-    retries=0,
-    interruptible=None,
-    deprecated="",
-    storage_request=None,
-    cpu_request=None,
-    gpu_request=None,
-    memory_request=None,
-    storage_limit=None,
-    cpu_limit=None,
-    gpu_limit=None,
-    memory_limit=None,
-    cache=False,
-    timeout=None,
-    environment=None,
-    cache_serialize=False,
-    cls=None,
-):
-    """
-    Decorator to create a hive task. This task should output a list of hive queries which are run on a hive cluster.
-
-    This is a two-step task:
-
-    1. The generator step runs in the user container and outputs a list of queries.
-    2. The list of queries produced in step 1 is then submitted to the Hive cluster. The queries are monitored by the
-       Flyte backend for completion.
-
-    Container properties (cpu, gpu, memory, etc.) set on this task are only used in step 1 above.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Integer)
-        @hive_task(
-            cache_version='1',
-        )
-        def test_hive(wf_params, a):
-            return [
-                "SELECT * FROM users_table WHERE user_id=4",
-                "INSERT INTO users_table VALUES ('user', 5, 4)"
-            ]
-
-    :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
-        take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
-        arguments are allowed for wrapped task functions.
-
-    :param Text cache_version: [optional] string representing logical version for discovery. This field should be
-        updated whenever the underlying algorithm changes.
-
-        .. note::
-
-            This argument is required to be a non-empty string if `cache` is True.
-
-    :param int retries: [optional] integer determining number of times task can be retried on
-        :py:exc:`flytekit.common.exceptions.RecoverableException` or transient platform failures. Defaults
-        to 0.
-
-        .. note::
-
-            If retries > 0, the task must be able to recover from any remote state created within the user code. It is
-            strongly recommended that tasks are written to be idempotent.
-
-    :param bool interruptible: [optional] boolean describing if task is interruptible.
-    :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
-        will be logged as a warning, so it should contain information regarding how to update to a newer task.
-    :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
-        for the task to run. Default is set by platform-level configuration.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs.
-        Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
-        necessary for the task to execute. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
-        for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
-        it is set equal to cpu_request.
-    :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not
-        guaranteed! If not specified, it is set equal to gpu_request.
-    :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
-        necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to
-        memory_request.
-    :param bool cache: [optional] boolean describing if the outputs of this task should be cached and
-        re-usable.
-    :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
-        run at most before triggering a retry (if retries are enabled). By default, tasks are allowed to run
-        indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not time out.
-    :param dict[Text,Text] environment: Environment variables to set for the execution of the query-generating
-        container.
-    :param bool cache_serialize: [optional] boolean describing if instances of this cacheable task should be executed
-        in serial. This means only a single instance executes and other concurrent executions wait for it to complete
-        and reuse the cached outputs.
-    :param cls: This can be used to override the task implementation with a user-defined extension. The class
-        provided should be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. Generally, it should be
-        a subclass of flytekit.common.tasks.hive_task.SdkHiveTask. A user can use this to inject bespoke logic into
-        the base Flyte programming model.
-
-    :rtype: flytekit.common.tasks.sdk_runnable.SdkHiveTask
-    """
-
-    def wrapper(fn):
-
-        return (cls or _sdk_hive_tasks.SdkHiveTask)(
-            task_function=fn,
-            task_type=_common_constants.SdkTaskType.BATCH_HIVE_TASK,
-            discovery_version=cache_version,
-            retries=retries,
-            interruptible=interruptible,
-            deprecated=deprecated,
-            storage_request=storage_request,
-            cpu_request=cpu_request,
-            gpu_request=gpu_request,
-            memory_request=memory_request,
-            storage_limit=storage_limit,
-            cpu_limit=cpu_limit,
-            gpu_limit=gpu_limit,
-            memory_limit=memory_limit,
-            discoverable=cache,
-            timeout=timeout or _datetime.timedelta(seconds=0),
-            cluster_label="",
-            tags=[],
-            environment=environment or {},
-            cache_serializable=cache_serialize,
-        )
-
-    if _task_function:
-        return wrapper(_task_function)
-    else:
-        return wrapper
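
In current flytekit, Hive queries are expressed as pre-registered query tasks rather than query-returning functions. A rough sketch using ``flytekitplugins-hive``; constructor details vary by plugin version, so treat the parameter names below as indicative rather than authoritative:

.. code-block:: python

    from flytekit import kwtypes
    from flytekitplugins.hive import HiveTask

    # A declarative query task: inputs are templated into the query string.
    user_lookup = HiveTask(
        name="users.lookup",
        inputs=kwtypes(user_id=int),
        query_template="SELECT * FROM users_table WHERE user_id = {{ .inputs.user_id }}",
    )
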
-
-def qubole_hive_task(
-    _task_function=None,
-    cache_version="",
-    retries=0,
-    interruptible=None,
-    deprecated="",
-    storage_request=None,
-    cpu_request=None,
-    gpu_request=None,
-    memory_request=None,
-    storage_limit=None,
-    cpu_limit=None,
-    gpu_limit=None,
-    memory_limit=None,
-    cache=False,
-    timeout=None,
-    cluster_label=None,
-    tags=None,
-    environment=None,
-    cache_serialize=False,
-    cls=None,
-):
-    """
-    Decorator to create a qubole hive task. This hive task runs on a Qubole cluster and therefore allows users to
-    pass cluster labels and qubole query tags. Like a hive task, this task should output a list of hive queries
-    that are run on a hive cluster.
-
-    Like a hive task, this is also a two-step task where step 2 runs on a qubole hive cluster. Therefore, users can
-    specify the qubole cluster_label and query tags on this task.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Integer)
-        @qubole_hive_task(
-            cache_version='1',
-            cluster_label='cluster_label',
-            tags=['tag1'],
-        )
-        def test_hive(wf_params, a):
-            return [
-                "SELECT * FROM users_table WHERE user_id=4",
-                "INSERT INTO users_table VALUES ('user', 5, 4)"
-            ]
-
-    :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
-        take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
-        arguments are allowed for wrapped task functions.
-
-    :param Text cache_version: [optional] string representing logical version for discovery. This field should be
-        updated whenever the underlying algorithm changes.
-
-        .. note::
-
-            This argument is required to be a non-empty string if `cache` is True.
-
-    :param int retries: [optional] integer determining number of times task can be retried on
-        :py:exc:`flytekit.common.exceptions.RecoverableException` or transient platform failures. Defaults
-        to 0.
-
-        .. note::
-
-            If retries > 0, the task must be able to recover from any remote state created within the user code. It is
-            strongly recommended that tasks are written to be idempotent.
-
-    :param bool interruptible: [optional] boolean describing if task is interruptible.
-    :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
-        will be logged as a warning, so it should contain information regarding how to update to a newer task.
-    :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
-        for the task to run. Default is set by platform-level configuration.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs.
-        Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
-        necessary for the task to execute. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
-        for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
-        it is set equal to cpu_request.
-    :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not
-        guaranteed! If not specified, it is set equal to gpu_request.
-    :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
-        necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to
-        memory_request.
-    :param bool cache: [optional] boolean describing if the outputs of this task should be cached and
-        re-usable.
-    :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
-        run at most before triggering a retry (if retries are enabled). By default, tasks are allowed to run
-        indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not time out.
-    :param cluster_label: The qubole cluster label where the query is to be executed.
-    :param list[Text] tags: User-defined tags (key-value pairs) for the queries. These tags are
-        passed to Qubole.
-    :param dict[Text,Text] environment: Environment variables to set for the execution of the query-generating
-        container.
-    :param bool cache_serialize: [optional] boolean describing if instances of this cacheable task should be executed
-        in serial. This means only a single instance executes and other concurrent executions wait for it to complete
-        and reuse the cached outputs.
-    :param cls: This can be used to override the task implementation with a user-defined extension. The class
-        provided should be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. Generally, it should be
-        a subclass of flytekit.common.tasks.hive_task.SdkHiveTask. A user can use this to inject bespoke logic into
-        the base Flyte programming model.
-
-    :rtype: flytekit.common.tasks.sdk_runnable.SdkHiveTask
-    """
-
-    def wrapper(fn):
-
-        return (cls or _sdk_hive_tasks.SdkHiveTask)(
-            task_function=fn,
-            task_type=_common_constants.SdkTaskType.BATCH_HIVE_TASK,
-            discovery_version=cache_version,
-            retries=retries,
-            interruptible=interruptible,
-            deprecated=deprecated,
-            storage_request=storage_request,
-            cpu_request=cpu_request,
-            gpu_request=gpu_request,
-            memory_request=memory_request,
-            storage_limit=storage_limit,
-            cpu_limit=cpu_limit,
-            gpu_limit=gpu_limit,
-            memory_limit=memory_limit,
-            discoverable=cache,
-            timeout=timeout or _datetime.timedelta(seconds=0),
-            cluster_label=cluster_label or "",
-            tags=tags or [],
-            environment=environment or {},
-            cache_serializable=cache_serialize,
-        )
-
-    # This is syntactic sugar, so that when calling this decorator without args, you can either
-    # do it with () or without any ()
-    if _task_function:
-        return wrapper(_task_function)
-    else:
-        return wrapper
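
That ``_task_function`` check is the standard pattern for decorators usable both bare and with arguments. A minimal, self-contained sketch of the mechanism (names here are hypothetical):

.. code-block:: python

    import functools


    def my_task(_task_function=None, retries=0):
        def wrapper(fn):
            @functools.wraps(fn)
            def wrapped(*args, **kwargs):
                # A real implementation would record metadata such as `retries` here.
                return fn(*args, **kwargs)

            return wrapped

        if _task_function:
            # Used as @my_task: the function arrived positionally, decorate it now.
            return wrapper(_task_function)
        # Used as @my_task(retries=2): hand back the decorator itself.
        return wrapper
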
-
-def sidecar_task(
-    _task_function=None,
-    cache_version="",
-    retries=0,
-    interruptible=None,
-    deprecated="",
-    storage_request=None,
-    cpu_request=None,
-    gpu_request=None,
-    memory_request=None,
-    storage_limit=None,
-    cpu_limit=None,
-    gpu_limit=None,
-    memory_limit=None,
-    cache=False,
-    timeout=None,
-    environment=None,
-    cache_serialize=False,
-    pod_spec=None,
-    primary_container_name=None,
-    annotations=None,
-    labels=None,
-    cls=None,
-):
-    """
-    Decorator to create a Sidecar Task definition. This task will execute the primary task alongside the specified
-    kubernetes PodSpec. Custom primary task container attributes can be defined in the PodSpec by defining a container
-    whose name matches the primary_container_name. These container attributes will be applied to the container brought
-    up to execute the primary task definition.
-
-    .. code-block:: python
-
-        def generate_pod_spec_for_task():
-            pod_spec = generated_pb2.PodSpec()
-            secondary_container = generated_pb2.Container(
-                name="secondary",
-                image="alpine",
-            )
-            secondary_container.command.extend(["/bin/sh"])
-            secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"])
-            shared_volume_mount = generated_pb2.VolumeMount(
-                name="shared-data",
-                mountPath="/data",
-            )
-            secondary_container.volumeMounts.extend([shared_volume_mount])
-
-            primary_container = generated_pb2.Container(name="primary")
-            primary_container.volumeMounts.extend([shared_volume_mount])
-
-            pod_spec.volumes.extend([generated_pb2.Volume(
-                name="shared-data",
-                volumeSource=generated_pb2.VolumeSource(
-                    emptyDir=generated_pb2.EmptyDirVolumeSource(
-                        medium="Memory",
-                    )
-                )
-            )])
-            pod_spec.containers.extend([primary_container, secondary_container])
-            return pod_spec
-
-        @sidecar_task(
-            pod_spec=generate_pod_spec_for_task(),
-            primary_container_name="primary",
-            annotations={"key": "value"},
-            labels={"key": "value"},
-        )
-        def a_sidecar_task(wfparams):
-            while not os.path.isfile('/data/message.txt'):
-                time.sleep(5)
-
-    :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
-        take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
-        arguments are allowed for wrapped task functions.
-
-    :param Text cache_version: [optional] string representing logical version for discovery. This field should be
-        updated whenever the underlying algorithm changes.
-
-        .. note::
-
-            This argument is required to be a non-empty string if `cache` is True.
-
-    :param int retries: [optional] integer determining number of times task can be retried on
-        :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults
-        to 0.
-
-        .. note::
-
-            If retries > 0, the task must be able to recover from any remote state created within the user code. It is
-            strongly recommended that tasks are written to be idempotent.
-
-    :param bool interruptible: Specify whether task is interruptible.
-
-    :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
-        will be logged as a warning, so it should contain information regarding how to update to a newer task.
-
-    :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
-        for the task to run. Default is set by platform-level configuration.
-
-        TODO: !!!
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
-
-        TODO: Add links to resource string documentation for Kubernetes
-
-    :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs.
-        Default is set by platform-level configuration.
-
-        TODO: Add links to resource string documentation for Kubernetes
-
-    :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
-        necessary for the task to execute. Default is set by platform-level configuration.
-
-        TODO: Add links to resource string documentation for Kubernetes
-
-    :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
-        for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request.
-
-        TODO: !!!
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
-        it is set equal to cpu_request.
-
-    :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not
-        guaranteed! If not specified, it is set equal to gpu_request.
-
-    :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
-        necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to
-        memory_request.
-
-    :param bool cache: [optional] boolean describing if the outputs of this task should be discoverable and
-        re-usable.
-
-    :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
-        run at most before triggering a retry (if retries are enabled). By default, tasks are allowed to run
-        indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not time out.
-
-    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
-
-    :param bool cache_serialize: [optional] boolean describing if instances of this cacheable task should be executed
-        in serial. This means only a single instance executes and other concurrent executions wait for it to complete
-        and reuse the cached outputs.
-
-    :param k8s.io.api.core.v1.generated_pb2.PodSpec pod_spec: [optional] PodSpec to bring up alongside task execution.
-
-    :param Text primary_container_name: primary container to monitor for the duration of the task.
-
-    :param dict[Text, Text] annotations: [optional] kubernetes annotations
-
-    :param dict[Text, Text] labels: [optional] kubernetes labels
-
-    :param cls: This can be used to override the task implementation with a user-defined extension. The class
-        provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to
-        inject bespoke logic into the base Flyte programming model.
-
-    :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
-    """
-
-    def wrapper(fn):
-
-        return (cls or _sdk_sidecar_tasks.SdkSidecarTask)(
-            task_function=fn,
-            task_type=_common_constants.SdkTaskType.SIDECAR_TASK,
-            discovery_version=cache_version,
-            retries=retries,
-            interruptible=interruptible,
-            deprecated=deprecated,
-            storage_request=storage_request,
-            cpu_request=cpu_request,
-            gpu_request=gpu_request,
-            memory_request=memory_request,
-            storage_limit=storage_limit,
-            cpu_limit=cpu_limit,
-            gpu_limit=gpu_limit,
-            memory_limit=memory_limit,
-            discoverable=cache,
-            timeout=timeout or _datetime.timedelta(seconds=0),
-            environment=environment,
-            cache_serializable=cache_serialize,
-            pod_spec=pod_spec,
-            primary_container_name=primary_container_name,
-            annotations=annotations,
-            labels=labels,
-        )
-
-    if _task_function:
-        return wrapper(_task_function)
-    else:
-        return wrapper
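
Today this maps onto the pod plugin, where the pod is described with the standard kubernetes Python client models instead of raw protobufs. A sketch, assuming ``flytekitplugins-pod`` and the ``kubernetes`` package are installed:

.. code-block:: python

    from flytekit import task
    from flytekitplugins.pod import Pod
    from kubernetes.client.models import V1Container, V1PodSpec


    def pod_spec() -> V1PodSpec:
        # The container named "primary" inherits the task's image, resources, etc.
        return V1PodSpec(
            containers=[
                V1Container(name="primary"),
                V1Container(name="secondary", image="alpine", command=["/bin/sh"], args=["-c", "echo hi"]),
            ]
        )


    @task(task_config=Pod(pod_spec=pod_spec(), primary_container_name="primary"))
    def a_pod_task() -> None:
        ...
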
-
-def dynamic_sidecar_task(
-    _task_function=None,
-    cache_version="",
-    retries=0,
-    interruptible=None,
-    deprecated="",
-    storage_request=None,
-    cpu_request=None,
-    gpu_request=None,
-    memory_request=None,
-    storage_limit=None,
-    cpu_limit=None,
-    gpu_limit=None,
-    memory_limit=None,
-    cache=False,
-    timeout=None,
-    allowed_failure_ratio=None,
-    max_concurrency=None,
-    environment=None,
-    cache_serialize=False,
-    pod_spec=None,
-    primary_container_name=None,
-    annotations=None,
-    labels=None,
-    cls=None,
-):
-    """
-    Decorator to create a custom dynamic sidecar task definition. Dynamic
-    tasks should be used to split up work into an arbitrary number of parallel
-    sub-tasks, or workflows. This task will execute the primary task alongside
-    the specified kubernetes PodSpec. Custom primary task container attributes
-    can be defined in the PodSpec by defining a container whose name matches
-    the primary_container_name. These container attributes will be applied to
-    the container brought up to execute the primary task definition.
-
-    .. code-block:: python
-
-        def generate_pod_spec_for_task():
-            pod_spec = generated_pb2.PodSpec()
-            secondary_container = generated_pb2.Container(
-                name="secondary",
-                image="alpine",
-            )
-            secondary_container.command.extend(["/bin/sh"])
-            secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"])
-            shared_volume_mount = generated_pb2.VolumeMount(
-                name="shared-data",
-                mountPath="/data",
-            )
-            secondary_container.volumeMounts.extend([shared_volume_mount])
-
-            primary_container = generated_pb2.Container(name="primary")
-            primary_container.volumeMounts.extend([shared_volume_mount])
-
-            pod_spec.volumes.extend([generated_pb2.Volume(
-                name="shared-data",
-                volumeSource=generated_pb2.VolumeSource(
-                    emptyDir=generated_pb2.EmptyDirVolumeSource(
-                        medium="Memory",
-                    )
-                )
-            )])
-            pod_spec.containers.extend([primary_container, secondary_container])
-            return pod_spec
-
-        @outputs(out=Types.Integer)
-        @python_task
-        def my_sub_task(wf_params, out):
-            out.set(randint())
-
-        @outputs(out=[Types.Integer])
-        @dynamic_sidecar_task(
-            pod_spec=generate_pod_spec_for_task(),
-            primary_container_name="primary",
-            annotations={"a": "a"},
-            labels={"b": "b"},
-        )
-        def my_task(wf_params, out):
-            out_list = []
-            for i in range(100):
-                out_list.append(my_sub_task().outputs.out)
-            out.set(out_list)
-
-    .. note::
-
-        All outputs of a batch task must be a list. This is because the individual outputs of sub-tasks should be
-        appended into a list. There cannot be aggregation of outputs done in this task. To accomplish aggregation,
-        it is recommended that a python_task take the outputs of this task as input and do the necessary work.
-        If a sub-task does not contribute an output, it must be yielded from the task with the `yield` keyword or
-        returned from the task in a list. If this isn't done, the sub-task will not be executed.
-
-    :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
-        take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
-        arguments are allowed.
-    :param Text cache_version: [optional] string representing logical version for discovery. This field should be
-        updated whenever the underlying algorithm changes.
-
-        .. note::
-
-            This argument is required to be a non-empty string if `cache` is True.
-
-    :param int retries: [optional] integer determining number of times task can be retried on
-        :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults
-        to 0.
-
-        .. note::
-
-            If retries > 0, the task must be able to recover from any remote state created within the user code. It is
-            strongly recommended that tasks are written to be idempotent.
-
-    :param bool interruptible: [optional] boolean describing if the task is interruptible.
-    :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
-        will be logged as a warning, so it should contain information regarding how to update to a newer task.
-    :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
-        for the task to run. Default is set by platform-level configuration.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs.
-        Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
-        necessary for the task to execute. Default is set by platform-level configuration.
-        TODO: Add links to resource string documentation for Kubernetes
-    :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
-        for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request.
-
-        .. note::
-
-            This is currently not supported by the platform.
-
-    :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute.
-        This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
-        it is set equal to cpu_request.
-    :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not
-        guaranteed! If not specified, it is set equal to gpu_request.
-    :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
-        necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to
-        memory_request.
-    :param bool cache: [optional] boolean describing if the outputs of this task should be cached and
-        re-usable.
-    :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
-        run at most before triggering a retry (if retries are enabled). By default, tasks are allowed to run
-        indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not time out.
-    :param float allowed_failure_ratio: [optional] float value describing the ratio of sub-tasks that may fail before
-        the master batch task considers itself a failure. By default, the value is 0, so if any task fails, the master
-        batch task will be marked a failure. If specified, the value must be between 0 and 1 inclusive. In the event a
-        non-zero value is specified, downstream tasks must be able to accept None values as outputs from individual
-        sub-tasks because the output values will be set to None for any sub-task that fails.
-    :param int max_concurrency: [optional] integer value describing the maximum number of tasks to run concurrently.
-        This is a stand-in pending better concurrency controls for special use-cases. The existence of this parameter
-        is not guaranteed between versions, and therefore it is NOT recommended that it be used.
-    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
-    :param bool cache_serialize: [optional] boolean describing if instances of this cacheable task should be executed
-        in serial. This means only a single instance executes and other concurrent executions wait for it to complete
-        and reuse the cached outputs.
-    :param k8s.io.api.core.v1.generated_pb2.PodSpec pod_spec: PodSpec to bring up alongside task execution.
-    :param Text primary_container_name: primary container to monitor for the duration of the task.
-    :param dict[Text, Text] annotations: [optional] kubernetes annotations
-    :param dict[Text, Text] labels: [optional] kubernetes labels
-    :param cls: This can be used to override the task implementation with a user-defined extension. The class
-        provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. Generally, it should be a
-        subclass of flytekit.common.tasks.sidecar_task.SdkDynamicSidecarTask. A user can use this parameter to inject
-        bespoke logic into the base Flyte programming model.
-    :rtype: flytekit.common.tasks.sidecar_task.SdkDynamicSidecarTask
-    """
-
-    def wrapper(fn):
-        return (cls or _sdk_sidecar_tasks.SdkDynamicSidecarTask)(
-            task_function=fn,
-            task_type=_common_constants.SdkTaskType.SIDECAR_TASK,
-            discovery_version=cache_version,
-            retries=retries,
-            interruptible=interruptible,
-            deprecated=deprecated,
-            storage_request=storage_request,
-            cpu_request=cpu_request,
-            gpu_request=gpu_request,
-            memory_request=memory_request,
-            storage_limit=storage_limit,
-            cpu_limit=cpu_limit,
-            gpu_limit=gpu_limit,
-            memory_limit=memory_limit,
-            discoverable=cache,
-            timeout=timeout or _datetime.timedelta(seconds=0),
-            allowed_failure_ratio=allowed_failure_ratio,
-            max_concurrency=max_concurrency,
-            environment=environment,
-            cache_serializable=cache_serialize,
-            pod_spec=pod_spec,
-            primary_container_name=primary_container_name,
-            annotations=annotations,
-            labels=labels,
-        )
-
-    if _task_function:
-        return wrapper(_task_function)
-    else:
-        return wrapper
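
The fan-out plus ``allowed_failure_ratio``/``max_concurrency`` combination survives most directly in map tasks in the current API, where ``min_success_ratio`` plays roughly the role of ``1 - allowed_failure_ratio``. A sketch under current ``flytekit.map_task`` semantics:

.. code-block:: python

    import typing

    from flytekit import map_task, task, workflow


    @task
    def square(n: int) -> int:
        return n * n


    @workflow
    def fan_out(xs: typing.List[int]) -> typing.List[int]:
        # concurrency caps parallel sub-task executions, much like max_concurrency did.
        return map_task(square, concurrency=10)(n=xs)
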
diff --git a/flytekit/sdk/test_utils.py b/flytekit/sdk/test_utils.py
deleted file mode 100644
index 766774bf2c..0000000000
--- a/flytekit/sdk/test_utils.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from wrapt import decorator as _decorator
-
-from flytekit.common import utils as _utils
-from flytekit.interfaces.data import data_proxy as _data_proxy
-
-
-class LocalTestFileSystem(object):
-    """
-    Context manager for creating a temporary test file system locally for the purpose of unit testing and grabbing
-    remote objects. This scratch space will be automatically cleaned up as long as sys.exit() is not called from
-    within the context. This context need only be used in user scripts and tests--all task executions are guaranteed
-    to have the necessary managed disk context available.
-
-    .. note::
-
-        This is especially useful when dealing with remote blob-like objects (Blob, CSV, MultiPartBlob,
-        MultiPartCSV, Schema) as they require backing on disk. Using this context manager creates that disk context
-        to support the downloads.
-
-    .. note::
-
-        Blob-like objects can be downloaded to user-specified locations. See documentation for
-        flytekit.sdk.types.Types for more information.
-
-    .. note::
-
-        When this context is entered, it overrides any contexts already entered. All blobs will be written to the most
-        recently entered context. Upon exiting the context, all associated data will be deleted. It is recommended to
-        only use one LocalTestFileSystem() per test to avoid confusion.
-
-    .. code-block:: python
-
-        with LocalTestFileSystem():
-            wf_handle = SdkWorkflowExecution.fetch('project', 'domain', 'name')
-            with wf_handle.node_executions['my_node'].outputs.blob as reader:
-                assert reader.read() == "hello!"
-    """
-
-    def __init__(self):
-        self._exit_stack = _utils.ExitStack()
-
-    def __enter__(self):
-        """
-        :rtype: flytekit.common.utils.AutoDeletingTempDir
-        """
-        self._exit_stack.__enter__()
-        temp_dir = self._exit_stack.enter_context(_utils.AutoDeletingTempDir("local_test_filesystem"))
-        self._exit_stack.enter_context(_data_proxy.LocalDataContext(temp_dir.name))
-        self._exit_stack.enter_context(_data_proxy.LocalWorkingDirectoryContext(temp_dir))
-        return temp_dir
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        return self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
-
-
-@_decorator
-def flyte_test(fn, _, args, kwargs):
-    """
-    This is a decorator which can be used to annotate test functions. By using this decorator, the necessary local
-    scratch context will be prepared and then cleaned up upon completion.
-
-    .. code-block:: python
-
-        @inputs(input_blob=Types.Blob)
-        @outputs(response_blob=Types.Blob)
-        @python_task
-        def test_task(wf_params, input_blob, response_blob):
-            response = Types.Blob()
-            with response as writer:
-                with input_blob as reader:
-                    txt = reader.read()
-                    if txt == "Hi":
-                        writer.write("Hello, world!")
-                    elif txt == "Goodnight":
-                        writer.write("Goodnight, moon.")
-                    else:
-                        writer.write("Does not compute.")
-            response_blob.set(response)
-
-        @flyte_test
-        def some_test():
-            blob = Types.Blob()
-            with blob as writer:
-                writer.write("Hi")
-
-            result = test_task.unit_test(input_blob=blob)
-
-            with result['response_blob'] as reader:
-                assert reader.read() == "Hello, world!"
-    """
-    with LocalTestFileSystem():
-        return fn(*args, **kwargs)
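
With the legacy test helpers gone, local testing leans on ``flytekit.testing`` (which this change keeps, as seen in the ``flytekit/testing/__init__.py`` hunk below). A small sketch of mocking a task in a unit test; the task and workflow names are illustrative:

.. code-block:: python

    from flytekit import task, workflow
    from flytekit.testing import task_mock


    @task
    def fetch(key: str) -> str:
        ...


    @workflow
    def wf(key: str) -> str:
        return fetch(key=key)


    def test_wf():
        # Replace the task body with a mock for the duration of the context.
        with task_mock(fetch) as mock:
            mock.return_value = "hello!"
            assert wf(key="greeting") == "hello!"
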
diff --git a/flytekit/sdk/types.py b/flytekit/sdk/types.py
deleted file mode 100644
index 7b64d1db49..0000000000
--- a/flytekit/sdk/types.py
+++ /dev/null
@@ -1,504 +0,0 @@
-from flytekit.common.types import blobs as _blobs
-from flytekit.common.types import containers as _containers
-from flytekit.common.types import helpers as _helpers
-from flytekit.common.types import primitives as _primitives
-from flytekit.common.types import proto as _proto
-from flytekit.common.types import schema as _schema
-
-
-class Types(object):
-    Integer = _helpers.get_sdk_type_from_literal_type(_primitives.Integer.to_flyte_literal_type())
-    """
-    Use this to specify a simple integer type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a Python int will be received.
-            2) Otherwise, a None value will be received.
-
-        As output:
-            1) User code may pass an int or long value.
-            2) Output can also be nulled with a None value.
-
-        From command-line:
-            Specify an integer or integer string.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Integer)
-        @outputs(b=Types.Integer)
-        @python_task
-        def double(wf_params, a, b):
-            b.set(a * 2)
-    """
-
-    Float = _helpers.get_sdk_type_from_literal_type(_primitives.Float.to_flyte_literal_type())
-    """
-    Use this to specify a simple floating point type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            A Python float will be received, if set. Otherwise, a None value will be received.
-
-        As output:
-            User code may pass a float value. It can also be nulled with a None value.
-
-        From command-line:
-            Specify a float or floating-point string.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Float)
-        @outputs(b=Types.Float)
-        @python_task
-        def square(wf_params, a, b):
-            b.set(a * a)
-    """
-
-    String = _helpers.get_sdk_type_from_literal_type(_primitives.String.to_flyte_literal_type())
-    """
-    Use this to specify a simple string type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            A Python string (str, or unicode under Python 2) will be received, if set. Otherwise, a None value will be
-            received.
-
-        As output:
-            User code may pass a str value (or a unicode value under Python 2). It can also be nulled with a None
-            value.
-
-        From command-line:
-            Specify a string.
-
-    .. code-block:: python
-
-        @inputs(a=Types.String, b=Types.String)
-        @outputs(c=Types.String)
-        @python_task
-        def concat(wf_params, a, b, c):
-            c.set(a + b)
-    """
-
-    Boolean = _helpers.get_sdk_type_from_literal_type(_primitives.Boolean.to_flyte_literal_type())
-    """
-    Use this to specify a simple bool type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            A Python bool will be received, if set. Otherwise, a None value will be received.
-
-        As output:
-            User code may pass a bool value. It can also be nulled with a None value.
-
-        From command-line:
-            Specify 0, 1, true, or false.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Boolean)
-        @outputs(b=Types.Boolean)
-        @python_task
-        def invert(wf_params, a, b):
-            b.set(not a)
-    """
-
-    Datetime = _helpers.get_sdk_type_from_literal_type(_primitives.Datetime.to_flyte_literal_type())
-    """
-    Use this to specify a simple datetime type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            A Python timezone-aware datetime.datetime will be received with a UTC time, if set. Otherwise,
-            a None value will be received.
-
-        As output:
-            User code may pass a timezone-aware datetime.datetime value. It can also be nulled with a None value.
-
-        From command-line:
-            Specify a timezone-aware, parsable datestring, e.g. 2019-01-01T00:00+00:00.
-
-    .. note::
-
-        The engine requires that datetimes be timezone aware. By default, Python datetime.datetime is not timezone
-        aware.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Datetime)
-        @outputs(b=Types.Datetime)
-        @python_task
-        def tomorrow(wf_params, a, b):
-            b.set(a + datetime.timedelta(days=1))
-    """
-
-    Timedelta = _helpers.get_sdk_type_from_literal_type(_primitives.Timedelta.to_flyte_literal_type())
-    """
-    Use this to specify a simple timedelta type.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            A Python datetime.timedelta will be received, if set. Otherwise, a None value will be received.
-
-        As output:
-            User code may pass a datetime.timedelta value. It can also be nulled with a None value.
-
-        From command-line:
-            Specify a parsable duration string, e.g. 1h30m24s.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Timedelta)
-        @outputs(b=Types.Timedelta)
-        @python_task
-        def hundred_times_longer(wf_params, a, b):
-            b.set(a * 100)
-    """
-
-    Generic = _helpers.get_sdk_type_from_literal_type(_primitives.Generic.to_flyte_literal_type())
-    """
-    Use this to specify a simple JSON type. The Generic type offers a flexible (but loose) extension to flyte's typing
-    system by allowing custom types/objects to be passed through. It's strongly recommended that producers and
-    consumers of a Generic type perform their own checks on the integrity of the object.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a Python dict with JSON-ifiable primitives and nested lists or maps.
-            2) Otherwise, a None value will be received.
-
-        As output:
-            1) User code may pass a Python dict with arbitrarily nested lists and dictionaries. JSON-ifiable
-               primitives may also be specified.
-            2) Output can also be nulled with a None value.
-
-        From command-line:
-            Specify a JSON string.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Generic)
-        @outputs(b=Types.Generic)
-        @python_task
-        def operate(wf_params, a, b):
-            if a['operation'] == 'add':
-                a['value'] += a['operand']  # a['value'] is a number
-            elif a['operation'] == 'merge':
-                a['value'].update(a['some']['nested'][0]['field'])
-            b.set(a)
-
-        # For better readability, it's strongly advised to leverage python's type aliasing.
-        MyTypeA = Types.Generic
-        MyTypeB = Types.Generic
-
-        # This makes it clearer that it received a certain type and produces a different one. Other tasks that consume
-        # MyTypeB should do so in their input declaration.
-        @inputs(a=MyTypeA)
-        @outputs(b=MyTypeB)
-        @python_task
-        def operate(wf_params, a, b):
-            if a['operation'] == 'add':
-                a['value'] += a['operand']  # a['value'] is a number
-            elif a['operation'] == 'merge':
-                a['value'].update(a['some']['nested'][0]['field'])
-            b.set(a)
-    """
-
-    Blob = _blobs.Blob
-    """
-    Use this to specify a Blob object which is essentially a managed file.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a :py:class:`flytekit.common.types.impl.blobs.Blob` object will be received.
-            2) If not set, a None value.
-
-        As output:
-            1) A user may specify a path string.
-            2) A user may construct a :py:class:`flytekit.common.types.impl.blobs.Blob` object and pass it as output.
-            3) Output can be nulled with a None value.
-
-        From command-line:
-            Specify a path to the blob. This path must be accessible from the container when executing--either by
-            being downloaded from an accessible remote location like s3 or as a local file.
-
-    .. code-block:: python
-
-        @inputs(a=Types.Blob)
-        @outputs(b=Types.Blob)
-        @python_task
-        def copy(wf_params, a, b):
-            with a as reader:
-                txt = reader.read()
-
-            out = Types.Blob()  # Create at a random location specified in flytekit configuration
-            with out as writer:
-                writer.write(txt)
-            b.set(out)
-    """
-
-    CSV = _blobs.CSV
-    """
-    Use this to specify a CSV blob object which is essentially a managed file in the CSV format.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a :py:class:`flytekit.common.types.impl.blobs.CSV` object will be received.
-            2) If not set, a None value.
-
-        As output:
-            1) A user may specify a path string.
-            2) A user may construct a :py:class:`flytekit.common.types.impl.blobs.CSV` object and pass it as output.
-            3) Output can be nulled with a None value.
-
-        From command-line:
-            Specify a path to the CSV. This path must be accessible from the container when executing--either by
-            being downloaded from an accessible remote location like s3 or as a local file.
-
-    .. code-block:: python
-
-        @inputs(a=Types.CSV)
-        @outputs(b=Types.CSV)
-        @python_task
-        def copy(wf_params, a, b):
-            with a as reader:
-                txt = reader.read()
-
-            out = Types.CSV()  # Create at a random location specified in flytekit configuration
-            with out as writer:
-                writer.write(txt)
-            b.set(out)
-    """
-
-    MultiPartBlob = _blobs.MultiPartBlob
-    """
-    Use this to specify a multi-part blob object which is essentially a chunked file in a non-recursive directory.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a :py:class:`flytekit.common.types.impl.blobs.MultiPartBlob` object will be received.
-            2) If not set, a None value.
-
-        As output:
-            1) A user may specify a path string.
-            2) A user may construct a :py:class:`flytekit.common.types.impl.blobs.MultiPartBlob` object and pass it as
-               output.
-            3) Output can be nulled with a None value.
-
-        From command-line:
-            Specify a path to the multi-part blob. This path must be accessible from the container when
-            executing--either by being downloaded from an accessible remote location like s3 or as a local file.
-
-    .. code-block:: python
-
-        @inputs(a=Types.MultiPartBlob)
-        @outputs(b=Types.MultiPartBlob)
-        @python_task
-        def concat_then_split(wf_params, a, b):
-            txt = ""
-            with a as chunks:
-                for chunk in chunks:
-                    txt += chunk.read()
-
-            out = Types.MultiPartBlob()  # Create at a random location specified in flytekit configuration
-            with out.create_part('000000') as writer:
-                writer.write("Chunk1")
-            with out.create_part('000001') as writer:
-                writer.write("Chunk2")
-            b.set(out)
-    """
-
-    MultiPartCSV = _blobs.MultiPartCSV
-    """
-    See :py:attr:`flytekit.sdk.types.Types.MultiPartBlob`, but in CSV format.
-    """
-
-    Schema = staticmethod(_schema.schema_instantiator)
-    """
-    Use this to specify a Schema blob object which is essentially a chunked stream of Parquet dataframes.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        Cast behavior:
-            1) A generic schema (specified as `Types.Schema()`) can receive input from any schema type regardless of
-               column definitions.
-            2) A schema can receive as input any schema object as long as the upstream schema has a superset of the
-               column names defined and the types match for paired columns. Ordering does not matter.
-
-        As input:
-            1) If set, a :py:class:`flytekit.common.types.impl.schema.Schema` object will be received.
-            2) If not set, a None value.
-
-        As output:
-            1) A user may specify a path string to a chunked dataframe non-recursive directory.
-            2) A user may construct a :py:class:`flytekit.common.types.impl.schema.Schema` object (with the correct
-               column definitions) and pass it as output.
-            3) Output can be nulled with a None value.
-
-        From command-line:
-            Specify a path to the schema object. This path must be accessible from the container when
-            executing--either by being downloaded from an accessible remote location like s3 or as a local file.
-
-    .. code-block:: python
-
-        @inputs(generic=Types.Schema(), typed=Types.Schema([('a', Types.Integer), ('b', Types.Float)]))
-        @outputs(b=Types.Schema([('a', Types.Integer), ('b', Types.Float)]))
-        @python_task
-        def concat_then_split(wf_params, generic, typed, b):
-            with typed as reader:
-                # Each chunk is loaded as a pandas.DataFrame object
-                for df in reader.iter_chunks():
-                    pass  # Operate on the dataframe
-
-            # Create at a random location specified in flytekit configuration
-            out = Types.Schema([('a', Types.Integer), ('b', Types.Float)])()
-            with out as writer:
-                writer.write(
-                    pandas.DataFrame.from_dict(
-                        {
-                            'a': [1, 2, 3],
-                            'b': [5.0, 6.0, 7.0]
-                        }
-                    )
-                )
-            b.set(out)
-    """
-
-    Proto = staticmethod(_proto.create_protobuf)
-    """
-    Proto type wraps a protobuf type to provide interoperability between protobuf and the flyte typing system. Using
-    this type, you can define custom input/output variable types of flyte entities and continue to provide strong
-    typing syntax. Proto type serializes proto objects as binary (leveraging flyteidl's Binary literal).
-    Binary serialization of protobufs is the most space-efficient serialization form. Because of the way protobufs are
-    designed, unmarshalling the serialized proto requires access to the corresponding type. In order to use/visualize
-    the serialized proto, you will generally need to write custom code in the corresponding component.
-
-    .. note::
-
-        The protobuf Python library should be installed on the PYTHONPATH to ensure the type engine can access the
-        appropriate Python code to deserialize the protobuf.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a Python protobuf object of the type specified in the definition.
-            2) If not set, a None value.
-
-        As output:
-            1) A Python protobuf object matching the type specified by the user.
-            2) Set None to null the output.
-
-        From command-line:
-            A base-64 encoded string of the serialized protobuf.
-
-    .. code-block:: python
-
-        from protos import my_protos_pb2
-
-        @inputs(a=Types.Proto(my_protos_pb2.Custom))
-        @outputs(b=Types.Proto(my_protos_pb2.Custom))
-        @python_task
-        def assert_and_create(wf_params, a, b):
-            assert a.field1 == 1
-            assert a.field2 == 'abc'
-            b.set(
-                my_protos_pb2.Custom(
-                    field1=100,
-                    field2='hello'
-                )
-            )
-    """
-
-    GenericProto = staticmethod(_proto.create_generic)
-    """
-    GenericProto type wraps a protobuf type to provide interoperability between protobuf and the flyte typing system.
-    Using this type, you can define custom input/output variable types of flyte entities and continue to provide
-    strong typing syntax. A generic proto is a specialization of the Generic type with added convenience functions to
-    support marshalling/unmarshalling of the underlying protobuf object using the protobuf official json marshaller.
-    While the GenericProto type does not produce the most space-efficient representation of protobufs, it's a suitable
-    solution for making protobufs easily accessible (i.e. humanly readable) in other flyte components (e.g. console,
-    cli... etc.).
-
-    .. note::
-
-        The protobuf Python library should be installed on the PYTHONPATH to ensure the type engine can access the
-        appropriate Python code to deserialize the protobuf.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a Python protobuf object of the type specified in the definition.
-            2) If not set, a None value.
-
-        As output:
-            1) A Python protobuf object matching the type specified by the user.
-            2) Set None to null the output.
-
-        From command-line:
-            A base-64 encoded string of the serialized protobuf.
-
-    .. code-block:: python
-
-        from protos import my_protos_pb2
-
-        @inputs(a=Types.GenericProto(my_protos_pb2.Custom))
-        @outputs(b=Types.GenericProto(my_protos_pb2.Custom))
-        @python_task
-        def assert_and_create(wf_params, a, b):
-            assert a.field1 == 1
-            assert a.field2 == 'abc'
-            b.set(
-                my_protos_pb2.Custom(
-                    field1=100,
-                    field2='hello'
-                )
-            )
-    """
-
-    List = staticmethod(_containers.List)
-    """
-    Use this to specify a list of any type--including nested lists.
-
-    When used with an SDK-decorated method, expect this behavior from the default type engine:
-
-        As input:
-            1) If set, a Python list populated with values matching the behavior of the list's sub-type.
-            2) If not set, a None value.
-
-        As output:
-            1) A Python list containing values adhering to the list's sub-type.
-            2) Set None to null the output.
-
-        From command-line:
-            Specify a valid JSON list string. The sub-values will be checked against the sub-type of the list.
-
-    .. note::
-
-        Shorthand syntax is supported of the form: `[Types.Integer]` in addition to longhand syntax like
-        `Types.List(Types.Integer)`. Both forms are equivalent.
-
-    .. note::
-
-        Lists can be arbitrarily deeply nested; however, the typing must be consistent between all sibling values in a
-        nested list.
-        Syntax for nesting is `[[[Types.Integer]]]` to create a 3D list of integers.
-
-    .. code-block:: python
-
-        @inputs(a=[Types.Integer])
-        @outputs(b=[Types.Integer])
-        @python_task
-        def square_each(wf_params, a, b):
-            b.set(
-                [x * x for x in a]
-            )
-    """
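
In the replacement API there is no ``Types`` registry at all; task signatures use native Python type hints. A sketch of the same ``square_each`` under the new ``@task``:

.. code-block:: python

    import typing

    from flytekit import task


    @task
    def square_each(a: typing.List[int]) -> typing.List[int]:
        # int, float, str, bool, datetime, timedelta, and nested lists map directly.
        return [x * x for x in a]
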
diff --git a/flytekit/sdk/workflow.py b/flytekit/sdk/workflow.py
deleted file mode 100644
index 34561dc592..0000000000
--- a/flytekit/sdk/workflow.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from typing import Dict
-
-import six as _six
-
-import flytekit.common.local_workflow
-from flytekit.common import nodes as _nodes
-from flytekit.common import promise as _promise
-from flytekit.common import workflow as _common_workflow
-from flytekit.common.types import helpers as _type_helpers
-
-
-class Input(_promise.Input):
-    """
-    This object should be used to specify inputs. It can be used in conjunction with
-    :py:meth:`flytekit.common.workflow.workflow` and :py:meth:`flytekit.common.workflow.workflow_class`.
-    """
-
-    def __init__(self, sdk_type, help=None, **kwargs):
-        """
-        :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: This is the SDK type necessary to create an
-            input to this workflow.
-        :param Text help: An optional help string to describe the input to users.
-        :param bool required: If set, default must be None.
-        :param T default: If this is not a required input, the value will default to this value. Specify as a kwarg.
-        """
-        super(Input, self).__init__("", _type_helpers.python_std_to_sdk_type(sdk_type), help=help, **kwargs)
-
-
-class Output(flytekit.common.local_workflow.Output):
-    """
-    This object should be used to specify outputs. It can be used in conjunction with
-    :py:meth:`flytekit.common.workflow.workflow` and :py:meth:`flytekit.common.workflow.workflow_class`.
-    """
-
-    def __init__(self, value, sdk_type=None, help=None):
-        """
-        :param T value:
-        :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: If specified, the value provided must
-            match this type exactly. If not provided, the SDK will attempt to infer the type. It is recommended
-            this value be provided as the SDK might not always be able to infer the correct type.
-        """
-        super(Output, self).__init__(
-            "",
-            value,
-            sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None,
-            help=help,
-        )
-
-
-def workflow_class(_workflow_metaclass=None, on_failure=None, disable_default_launch_plan=False, cls=None):
-    """
-    This is a decorator for wrapping class definitions into workflows.
-
-    .. code-block:: python
-
-        @workflow_class
-        class MyWorkflow(object):
-            a = Input(Types.Integer, default=100, help="Tell me something")
-            b = Input(Types.Float, required=True)
-            first_task = my_task(a=a)
-            second_task = my_other_task(b=b, c=first_task.outputs.c)
-            d = Output(second_task.outputs.d)
-
-    :param T _workflow_metaclass: Do NOT specify this parameter directly. This is the class that is being
-        wrapped by this decorator.
-    :param flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy on_failure: [Optional] The execution policy
-        when the workflow detects a failure.
-    :param bool disable_default_launch_plan: Determines whether to create a default launch plan for the workflow or
-        not.
-    :param cls: This is the class that will be instantiated from the inputs, outputs, and nodes. This will be used
-        by users extending the base Flyte programming model. If set, it must be a subclass of
-        :py:class:`flytekit.common.local_workflow.PythonWorkflow`.
-
-    :rtype: flytekit.common.workflow.SdkWorkflow
-    """
-
-    def wrapper(metaclass):
-        wf = flytekit.common.local_workflow.build_sdk_workflow_from_metaclass(
-            metaclass, on_failure=on_failure, disable_default_launch_plan=disable_default_launch_plan, cls=cls
-        )
-        return wf
-
-    if _workflow_metaclass is not None:
-        return wrapper(_workflow_metaclass)
-    return wrapper
- wf = (cls or flytekit.common.local_workflow.SdkRunnableWorkflow).construct_from_class_definition( - inputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(inputs or {}))], - outputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(outputs or {}))], - nodes=[v.assign_id_and_return(k) for k, v in sorted(_six.iteritems(nodes))], - metadata=_common_workflow._workflow_models.WorkflowMetadata(on_failure=on_failure), - ) - return wf diff --git a/flytekit/testing/__init__.py b/flytekit/testing/__init__.py index 15340a7014..bb75358198 100644 --- a/flytekit/testing/__init__.py +++ b/flytekit/testing/__init__.py @@ -16,5 +16,4 @@ """ -from flytekit.common.tasks.sdk_runnable import SecretsManager from flytekit.core.testing import patch, task_mock diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index f53a9b0b96..f9b9fc6665 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -6,11 +6,12 @@ import checksumdir -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.interfaces.data.data_proxy import Data as _Data +from flytekit.core.context_manager import FlyteContextManager _tmp_versions_dir = "tmp/versions" +file_access = FlyteContextManager.current_context().file_access + def compute_digest(source_dir: _os.PathLike) -> str: """ @@ -70,7 +71,7 @@ def upload_package(source_dir: _os.PathLike, identifier: str, remote_location: s print("Local marker for identifier {} already exists, skipping upload".format(identifier)) return full_remote_path - if _Data.data_exists(full_remote_path): + if file_access.exists(full_remote_path): print("Remote file {} already exists, skipping upload".format(full_remote_path)) _write_marker(marker) return full_remote_path @@ -82,7 +83,7 @@ def upload_package(source_dir: _os.PathLike, identifier: str, remote_location: s if dry_run: print("Would upload {} to {}".format(fp.name, full_remote_path)) else: - _Data.put_data(fp.name, full_remote_path) + file_access.put_data(fp.name, full_remote_path) print("Uploaded {} to {}".format(fp.name, full_remote_path)) # Finally, touch the marker file so we have a flag in the future to avoid re-uploading the package dir as an @@ -97,7 +98,7 @@ def download_distribution(additional_distribution: str, destination: str): :param Text additional_distribution: :param _os.PathLike destination: """ - _data_proxy.Data.get_data(additional_distribution, destination) + file_access.get_data(additional_distribution, destination) tarfile_name = _os.path.basename(additional_distribution) file_suffix = _Path(tarfile_name).suffixes if len(file_suffix) != 2 or file_suffix[0] != ".tar" or file_suffix[1] != ".gz": diff --git a/flytekit/tools/lazy_loader.py b/flytekit/tools/lazy_loader.py deleted file mode 100644 index a174769673..0000000000 --- a/flytekit/tools/lazy_loader.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import importlib as _importlib -import sys as _sys -import types as _types -from typing import List - - -class LazyLoadPlugin(object): - LAZY_LOADING_PLUGINS = {} - - def __init__(self, plugin_name, plugin_requirements, related_modules: List[_LazyLoadModule]): - """ - :param Text plugin_name: - :param list[Text] plugin_requirements: - :param list[LazyLoadModule] related_modules: - """ - type(self).LAZY_LOADING_PLUGINS[plugin_name] = plugin_requirements - for m in related_modules: - type(m).tag_with_plugin(plugin_name) - - @classmethod - def get_extras_require(cls): - """ - :rtype: 
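The `fast_registration` hunk above replaces the removed `Data` proxy with the `file_access` provider taken from the current `FlyteContext`. A short sketch of the new call pattern, using a hypothetical bucket path:

.. code-block:: python

    from flytekit.core.context_manager import FlyteContextManager

    file_access = FlyteContextManager.current_context().file_access

    remote = "s3://my-bucket/fast/abc123.tar.gz"  # hypothetical location
    if not file_access.exists(remote):
        # upload a local archive, mirroring upload_package above
        file_access.put_data("/tmp/pkg.tar.gz", remote)
    # download, mirroring download_distribution above
    file_access.get_data(remote, "/tmp/pkg.tar.gz")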
dict[Text,list[Text]] - """ - d = cls.LAZY_LOADING_PLUGINS.copy() - all_plugins_spark2 = [] - all_plugins_spark3 = [] - for k in d: - # Default to Spark 2.4.x in all-spark2 and Spark 3.x in all-spark3. - if k != "spark3": - all_plugins_spark2.extend(d[k]) - if k != "spark": - all_plugins_spark3.extend(d[k]) - - d["all-spark2.4"] = all_plugins_spark2 - d["all-spark3"] = all_plugins_spark3 - # all points to Spark 3.x. - # Spark 2.4 to be fully removed in a future release. - d["all"] = all_plugins_spark3 - return d - - -def lazy_load_module(module: str) -> _types.ModuleType: - """ - :param Text module: - :rtype: _types.ModuleType - """ - - class LazyLoadModule(_LazyLoadModule): - _module = module - _lazy_submodules = dict() - _plugins = [] - - return LazyLoadModule(module) - - -class _LazyLoadModule(_types.ModuleType): - _ERROR_MSG_FMT = ( - "Attempting to use a plugin functionality that requires module " - "`{module}`, but it couldn't be loaded. Please pip install at least one of {plugins} or " - "`flytekit[all]` to get these dependencies.\n" - "\n" - "Original message: {msg}" - ) - - @classmethod - def _load(cls): - module = _sys.modules.get(cls._module) - if not module: - try: - module = _importlib.import_module(cls._module) - except ImportError as e: - raise ImportError(cls._ERROR_MSG_FMT.format(module=cls._module, plugins=cls._plugins, msg=e)) - return module - - def __getattribute__(self, item): - if item in type(self)._lazy_submodules: - return type(self)._lazy_submodules[item] - m = type(self)._load() - return getattr(m, item) - - def __setattr__(self, key, value): - m = type(self)._load() - return setattr(m, key, value) - - @classmethod - def _add_sub_module(cls, submodule): - """ - Add a submodule. - :param Text submodule: This should be a single submodule. Do NOT include periods - :rtype: LazyLoadModule - """ - m = cls._lazy_submodules.get(submodule) - if not m: - m = cls._lazy_submodules[submodule] = lazy_load_module("{}.{}".format(cls._module, submodule)) - return m - - @classmethod - def add_sub_module(cls, submodule): - """ - Add a submodule. - :param Text submodule: If periods are included, it will be added recursively - :rtype: LazyLoadModule - """ - parts = submodule.split(".", 1) - m = cls._add_sub_module(parts[0]) - if len(parts) > 1: - m = type(m).add_sub_module(parts[1]) - return m - - @classmethod - def tag_with_plugin(cls, p: LazyLoadPlugin): - """ - :param LazyLoadPlugin p: - """ - cls._plugins.append(p) diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index 182b11de72..bc0c46bbbf 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -5,10 +5,6 @@ import sys from typing import Any, Iterator, List, Union -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.local_workflow import SdkRunnableWorkflow as _SdkRunnableWorkflow -from flytekit.common.mixins import registerable as _registerable - def iterate_modules(pkgs): for package_name in pkgs: @@ -56,71 +52,6 @@ def load_workflow_modules(pkgs): pass -def _topo_sort_helper( - obj, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities, -): - visited.add(obj) - recursion_stack.append(obj) - if obj in recursion_set: - raise _user_exceptions.FlyteAssertion( - "A cyclical dependency was detected during topological sort of entities. 
" - "Cycle path was:\n\n\t{}".format("\n\t".join(p for p in recursion_stack[recursion_set[obj] :])) - ) - recursion_set[obj] = len(recursion_stack) - 1 - - if isinstance(obj, _registerable.HasDependencies): - for upstream in obj.upstream_entities: - if upstream.has_registered: - continue - if upstream not in visited: - for m1, k1, o1 in _topo_sort_helper( - upstream, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities, - ): - if not o1.has_registered: - yield m1, k1, o1 - - recursion_stack.pop() - del recursion_set[obj] - - if isinstance(obj, include_entities) or not isinstance(obj, ignore_entities): - if obj in entity_to_module_key: - yield entity_to_module_key[obj] + (obj,) - elif detect_unreferenced_entities: - raise _user_exceptions.FlyteAssertion( - f"An entity ({obj.id}) was not found in modules accessible from the workflow packages configuration. Please " - f"ensure that entities in '{obj.instantiated_in}' are moved to a configured packaged, or adjust the configuration." - ) - - -def _get_entity_to_module(pkgs): - entity_to_module_key = {} - for m in iterate_modules(pkgs): - for k in dir(m): - o = m.__dict__[k] - if isinstance(o, _registerable.RegisterableEntity) and not o.has_registered: - if o.instantiated_in == m.__name__: - entity_to_module_key[o] = (m, k) - if isinstance(o, _SdkRunnableWorkflow) and o.should_create_default_launch_plan: - # SDK should create a default launch plan for a workflow. This is a special-case to simplify - # authoring of workflows. - entity_to_module_key[o.create_launch_plan()] = (m, k) - return entity_to_module_key - - def load_module_object_for_type(pkgs, t, additional_path=None): def iterate(): entity_to_module_key = {} @@ -138,66 +69,33 @@ def iterate(): return iterate() -def iterate_registerable_entities_in_order( +def load_object_from_module(object_location: str) -> Any: + """ + # TODO: Handle corner cases, like where the first part is [] maybe + """ + class_obj = object_location.split(".") + class_obj_mod = class_obj[:-1] # e.g. ['flytekit', 'core', 'python_auto_container'] + class_obj_key = class_obj[-1] # e.g. 'default_task_class_obj' + class_obj_mod = importlib.import_module(".".join(class_obj_mod)) + return getattr(class_obj_mod, class_obj_key) + + +def trigger_loading( pkgs, local_source_root=None, - ignore_entities=None, - include_entities=None, - detect_unreferenced_entities=True, ): """ This function will iterate all discovered entities in the given package list. It will then attempt to topologically sort such that any entity with a dependency on another comes later in the list. Note that workflows can reference other workflows and launch plans. + :param list[Text] pkgs: :param Text local_source_root: - :param set[type] ignore_entities: If specified, ignore these entities while doing a topological sort. All other - entities will be taken. Only one of ignore_entities or include_entities can be set. - :param set[type] include_entities: If specified, include these entities while doing a topological sort. All - other entities will be ignored. Only one of ignore_entities or include_entities can be set. - :param bool detect_unreferenced_entities: If true, we will raise exceptions on entities not included in the package - configuration. 
- :rtype: module, Text, flytekit.common.mixins.registerable.RegisterableEntity """ - if ignore_entities and include_entities: - raise _user_exceptions.FlyteAssertion("ignore_entities and include_entities cannot both be set") - elif not ignore_entities and not include_entities: - include_entities = (object,) - ignore_entities = tuple() - else: - ignore_entities = tuple(list(ignore_entities or set([object]))) - include_entities = tuple(list(include_entities or set())) - if local_source_root is not None: with add_sys_path(local_source_root): - entity_to_module_key = _get_entity_to_module(pkgs) + for _ in iterate_modules(pkgs): + ... else: - entity_to_module_key = _get_entity_to_module(pkgs) - - visited = set() - for o in entity_to_module_key.keys(): - if o not in visited: - recursion_set = dict() - recursion_stack = [] - for m, k, o2 in _topo_sort_helper( - o, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities=detect_unreferenced_entities, - ): - yield m, k, o2 - - -def load_object_from_module(object_location: str) -> Any: - """ - # TODO: Handle corner cases, like where the first part is [] maybe - """ - class_obj = object_location.split(".") - class_obj_mod = class_obj[:-1] # e.g. ['flytekit', 'core', 'python_auto_container'] - class_obj_key = class_obj[-1] # e.g. 'default_task_class_obj' - class_obj_mod = importlib.import_module(".".join(class_obj_mod)) - return getattr(class_obj_mod, class_obj_key) + for _ in iterate_modules(pkgs): + ... diff --git a/flytekit/common/translator.py b/flytekit/tools/translator.py similarity index 99% rename from flytekit/common/translator.py rename to flytekit/tools/translator.py index deca82f7fd..83d724e4ce 100644 --- a/flytekit/common/translator.py +++ b/flytekit/tools/translator.py @@ -1,8 +1,7 @@ from collections import OrderedDict from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit.common import constants as _common_constants -from flytekit.common.utils import _dnsify +from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode from flytekit.core.context_manager import SerializationSettings @@ -11,6 +10,7 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask +from flytekit.core.utils import _dnsify from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase from flytekit.models import common as _common_models from flytekit.models import interface as interface_models diff --git a/flytekit/type_engines/__init__.py b/flytekit/type_engines/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/type_engines/common.py b/flytekit/type_engines/common.py deleted file mode 100644 index 33927325ce..0000000000 --- a/flytekit/type_engines/common.py +++ /dev/null @@ -1,31 +0,0 @@ -import abc as _abc - - -class TypeEngine(object, metaclass=_abc.ABCMeta): - @_abc.abstractmethod - def python_std_to_sdk_type(self, t): - """ - Converts a standard format for specifying types in Python to the Flyte typing structure. - :param T t: User input. Usually of the form: Types.Integer, [Types.Integer], {Types.String: - Types.Integer}, etc. 
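With the topological-sort machinery removed, `module_loader` is left with two small entry points: `trigger_loading`, which simply imports every module under the given packages (registration happens as a side effect of import), and `load_object_from_module`, which resolves a dotted path to the object it names. A sketch, where the package name is hypothetical:

.. code-block:: python

    from flytekit.tools.module_loader import load_object_from_module, trigger_loading

    # import everything under a package so its tasks/workflows get defined
    trigger_loading(["myapp.workflows"])  # hypothetical package

    # splits "pkg.module.attr", imports "pkg.module", getattrs "attr"
    resolver = load_object_from_module(
        "flytekit.core.python_auto_container.default_task_resolver"
    )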
- :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - pass - - @_abc.abstractmethod - def get_sdk_type_from_literal_type(self, literal_type): - """ - Takes the Flyte spec language and converts to an SDK object. - :param flytekit.models.types.LiteralType literal_type: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - pass - - @_abc.abstractmethod - def infer_sdk_type_from_literal(self, literal): - """ - From a literal value, we infer the correct SDK type. - :param flytekit.models.literals.Literal literal: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - pass diff --git a/flytekit/type_engines/default/__init__.py b/flytekit/type_engines/default/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py deleted file mode 100644 index 0ec3a4c982..0000000000 --- a/flytekit/type_engines/default/flyte.py +++ /dev/null @@ -1,193 +0,0 @@ -import importlib as _importer -from typing import Type - -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.common.types import blobs as _blobs -from flytekit.common.types import containers as _container_types -from flytekit.common.types import helpers as _helpers -from flytekit.common.types import primitives as _primitive_types -from flytekit.common.types import proto as _proto -from flytekit.common.types import schema as _schema -from flytekit.models import types as _literal_type_models -from flytekit.models.core import types as _core_types - - -def _load_type_from_tag(tag: str) -> Type: - """ - Loads python type from tag - """ - - if "." not in tag: - raise _user_exceptions.FlyteValueException( - tag, - "Protobuf tag must include at least one '.' to delineate package and object name.", - ) - - module, name = tag.rsplit(".", 1) - try: - pb_module = _importer.import_module(module) - except ImportError: - raise _user_exceptions.FlyteAssertion( - "Could not resolve the protobuf definition @ {}. Is the protobuf library installed?".format(module) - ) - - if not hasattr(pb_module, name): - raise _user_exceptions.FlyteAssertion("Could not find the protobuf named: {} @ {}.".format(name, module)) - - return getattr(pb_module, name) - - -def _proto_sdk_type_from_tag(tag): - """ - :param Text tag: - :rtype: _proto.Protobuf - """ - return _proto.create_protobuf(_load_type_from_tag(tag)) - - -def _generic_proto_sdk_type_from_tag(tag: str) -> Type[_proto.GenericProtobuf]: - """ - :param Text tag: - :rtype: _proto.GenericProtobuf - """ - - return _proto.create_generic(_load_type_from_tag(tag)) - - -class FlyteDefaultTypeEngine(object): - _SIMPLE_TYPE_LOOKUP_TABLE = { - _literal_type_models.SimpleType.INTEGER: _primitive_types.Integer, - _literal_type_models.SimpleType.FLOAT: _primitive_types.Float, - _literal_type_models.SimpleType.BOOLEAN: _primitive_types.Boolean, - _literal_type_models.SimpleType.DATETIME: _primitive_types.Datetime, - _literal_type_models.SimpleType.DURATION: _primitive_types.Timedelta, - _literal_type_models.SimpleType.NONE: _base_sdk_types.Void, - _literal_type_models.SimpleType.STRING: _primitive_types.String, - _literal_type_models.SimpleType.STRUCT: _primitive_types.Generic, - } - - def python_std_to_sdk_type(self, t): - """ - :param T t: User input. 
Should be of the form: Types.Integer, [Types.Integer], {Types.String: - Types.Integer}, etc. - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - if isinstance(t, list): - if len(t) != 1: - raise _user_exceptions.FlyteAssertion( - "When specifying a list type, there must be exactly one element in " - "the list describing the contained type." - ) - return _container_types.List(_helpers.python_std_to_sdk_type(t[0])) - elif isinstance(t, dict): - raise _user_exceptions.FlyteAssertion("Map types are not yet implemented.") - elif isinstance(t, _base_sdk_types.FlyteSdkType): - return t - else: - raise _user_exceptions.FlyteTypeException( - type(t), - _base_sdk_types.FlyteSdkType, - additional_msg="Should be of form similar to: Types.Integer, [Types.Integer], {Types.String: " - "Types.Integer}", - received_value=t, - ) - - def get_sdk_type_from_literal_type(self, literal_type): - """ - :param flytekit.models.types.LiteralType literal_type: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - if literal_type.collection_type is not None: - return _container_types.List(_helpers.get_sdk_type_from_literal_type(literal_type.collection_type)) - elif literal_type.map_value_type is not None: - raise NotImplementedError("TODO: Implement map") - elif literal_type.schema is not None: - return _schema.schema_instantiator_from_proto(literal_type.schema) - elif literal_type.blob is not None: - return self._get_blob_impl_from_type(literal_type.blob) - elif literal_type.simple is not None: - if ( - literal_type.simple == _literal_type_models.SimpleType.BINARY - and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata - ): - return _proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) - if ( - literal_type.simple == _literal_type_models.SimpleType.STRUCT - and literal_type.metadata - and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata - ): - return _generic_proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) - sdk_type = self._SIMPLE_TYPE_LOOKUP_TABLE.get(literal_type.simple) - if sdk_type is None: - raise NotImplementedError( - "We haven't implemented this type yet: Simple type={}".format(literal_type.simple) - ) - return sdk_type - else: - raise _system_exceptions.FlyteSystemAssertion( - "An unrecognized literal type was received: {}".format(literal_type) - ) - - def infer_sdk_type_from_literal(self, literal): # noqa - """ - :param flytekit.models.literals.Literal literal: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - if literal.collection is not None: - if len(literal.collection.literals) > 0: - sdk_type = _container_types.List(_helpers.infer_sdk_type_from_literal(literal.collection.literals[0])) - else: - sdk_type = _container_types.List(_base_sdk_types.Void) - elif literal.map is not None: - raise NotImplementedError("TODO: Implement map") - elif literal.scalar.blob is not None: - sdk_type = self._get_blob_impl_from_type(literal.scalar.blob.metadata.type) - elif literal.scalar.none_type is not None: - sdk_type = _base_sdk_types.Void - elif literal.scalar.schema is not None: - sdk_type = _schema.schema_instantiator_from_proto(literal.scalar.schema.type) - elif literal.scalar.error is not None: - raise NotImplementedError("TODO: Implement error from literal map") - elif literal.scalar.generic is not None: - sdk_type = _primitive_types.Generic - elif literal.scalar.binary is not None: - if literal.scalar.binary.tag.startswith(_proto.Protobuf.TAG_PREFIX): - sdk_type = 
_proto_sdk_type_from_tag(literal.scalar.binary.tag[len(_proto.Protobuf.TAG_PREFIX) :]) - else: - raise NotImplementedError("TODO: Binary is only supported for protobuf types currently") - elif literal.scalar.primitive.boolean is not None: - sdk_type = _primitive_types.Boolean - elif literal.scalar.primitive.datetime is not None: - sdk_type = _primitive_types.Datetime - elif literal.scalar.primitive.duration is not None: - sdk_type = _primitive_types.Timedelta - elif literal.scalar.primitive.float_value is not None: - sdk_type = _primitive_types.Float - elif literal.scalar.primitive.integer is not None: - sdk_type = _primitive_types.Integer - elif literal.scalar.primitive.string_value is not None: - sdk_type = _primitive_types.String - else: - raise _system_exceptions.FlyteSystemAssertion("Received unknown literal: {}".format(literal)) - return sdk_type - - def _get_blob_impl_from_type(self, blob_type): - """ - :param flytekit.models.core.types.BlobType blob_type: - :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType - """ - if blob_type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE: - if blob_type.format == "csv": - return _blobs.CSV - else: - return _blobs.Blob - elif blob_type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART: - if blob_type.format == "csv": - return _blobs.MultiPartCSV - else: - return _blobs.MultiPartBlob - raise _system_exceptions.FlyteSystemAssertion( - "Flyte's base type engine does not support this type of blob. Value: {}".format(blob_type) - ) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index ca34d8a23f..b4fd001180 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -10,6 +10,7 @@ from typing import Type import numpy as _np +import pandas from dataclasses_json import config, dataclass_json from marshmallow import fields @@ -17,7 +18,6 @@ from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType -from flytekit.plugins import pandas T = typing.TypeVar("T") diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py index 629c8c588d..ae805d0e81 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py @@ -9,10 +9,9 @@ from google.protobuf.json_format import MessageToDict from flytekit import FlyteContext -from flytekit.common.types import primitives from flytekit.extend import DictTransformer, PythonTask, SerializationSettings, TypeEngine, TypeTransformer from flytekit.models.literals import Literal -from flytekit.models.types import LiteralType +from flytekit.models.types import LiteralType, SimpleType from .models import hpo_job as _hpo_job_model from .models import parameter_ranges as _params @@ -107,7 +106,7 @@ def __init__(self): super().__init__("sagemaker-hpojobconfig-transformer", _hpo_job_model.HyperparameterTuningJobConfig) def get_literal_type(self, t: Type[_hpo_job_model.HyperparameterTuningJobConfig]) -> LiteralType: - return primitives.Generic.to_flyte_literal_type() + return LiteralType(simple=SimpleType.STRUCT, metadata=None) def to_literal( self, @@ -139,7 +138,7 @@ def __init__(self): super().__init__("sagemaker-paramrange-transformer", _params.ParameterRangeOneOf) def get_literal_type(self, t: Type[_params.ParameterRangeOneOf]) -> 
LiteralType: - return primitives.Generic.to_flyte_literal_type() + return LiteralType(simple=SimpleType.STRUCT, metadata=None) def to_literal( self, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py index 4328749d72..0df8f42dba 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py @@ -2,7 +2,7 @@ from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _idl_parameter_ranges -from flytekit.common.exceptions import user +from flytekit.exceptions import user from flytekit.models import common as _common diff --git a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py b/plugins/flytekit-aws-sagemaker/tests/test_hpo.py index 28a226e696..e52994c664 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_hpo.py @@ -24,7 +24,7 @@ from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig from flytekit import FlyteContext -from flytekit.common.types.primitives import Generic +from flytekit.models.types import LiteralType, SimpleType from .test_training import _get_reg_settings @@ -93,7 +93,7 @@ def test_hpo_for_builtin(): def test_hpoconfig_transformer(): t = HPOTuningJobConfigTransformer() - assert t.get_literal_type(HyperparameterTuningJobConfig) == Generic.to_flyte_literal_type() + assert t.get_literal_type(HyperparameterTuningJobConfig) == LiteralType(simple=SimpleType.STRUCT) o = HyperparameterTuningJobConfig( tuning_strategy=1, tuning_objective=HyperparameterTuningObjective( @@ -113,7 +113,7 @@ def test_hpoconfig_transformer(): def test_parameter_ranges_transformer(): t = ParameterRangesTransformer() - assert t.get_literal_type(ParameterRangeOneOf) == Generic.to_flyte_literal_type() + assert t.get_literal_type(ParameterRangeOneOf) == LiteralType(simple=SimpleType.STRUCT) o = ParameterRangeOneOf(param=IntegerParameterRange(10, 0, 1)) ctx = FlyteContext.current_context() lit = t.to_literal(ctx, python_val=o, python_type=ParameterRangeOneOf, expected=None) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_training.py b/plugins/flytekit-aws-sagemaker/tests/test_training.py index 3e6ada1ff5..a48d3c9f39 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_training.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_training.py @@ -14,7 +14,7 @@ import flytekit from flytekit import task -from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import Image, ImageConfig, SerializationSettings diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index 05b7cdd5e1..6cbf7d57fa 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -5,7 +5,7 @@ from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements from flytekit import FlyteContext, PythonFunctionTask -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.extend import Promise, SerializationSettings, TaskPlugins from flytekit.models import task as _task_models diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py 
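Two related replacements run through these hunks: the deleted `FlyteDefaultTypeEngine` gives way to `flytekit.core.type_engine.TypeEngine` (already imported by `types/schema/types.py` above), and the SageMaker transformers now build a struct `LiteralType` directly instead of calling `primitives.Generic.to_flyte_literal_type()`. The updated tests assert the two forms are equal; a sketch of both patterns:

.. code-block:: python

    import typing

    from flytekit.core.type_engine import TypeEngine
    from flytekit.models.types import LiteralType, SimpleType

    # native type hints replace Types.Integer / [Types.Integer]
    list_type = TypeEngine.to_literal_type(typing.List[int])
    assert list_type.collection_type is not None

    # what Generic.to_flyte_literal_type() used to produce
    struct_type = LiteralType(simple=SimpleType.STRUCT, metadata=None)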
b/plugins/flytekit-k8s-pod/tests/test_pod.py index ed0229f02b..82a0bcf6f8 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -8,10 +8,10 @@ from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount from flytekit import Resources, TaskMetadata, dynamic, map_task, task -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import FastSerializationSettings from flytekit.extend import ExecutionState, Image, ImageConfig, SerializationSettings +from flytekit.tools.translator import get_serializable def get_pod_spec(): diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 1a7845a8fe..4b1e183952 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -11,7 +11,7 @@ from nbconvert import HTMLExporter from flytekit import FlyteContext, PythonInstanceTask -from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.models.literals import LiteralMap from flytekit.types.file import HTMLPage, PythonNotebook diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index d03949f1ac..53b1620331 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -1,10 +1,17 @@ +import enum import typing from flyteidl.plugins import spark_pb2 as _spark_task -from flytekit.common.exceptions import user as _user_exceptions +from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common -from flytekit.sdk.spark_types import SparkType as _spark_type + + +class SparkType(enum.Enum): + PYTHON = 1 + SCALA = 2 + JAVA = 3 + R = 4 class SparkJob(_common.FlyteIdlEntity): @@ -102,13 +109,13 @@ def to_flyte_idl(self): :rtype: flyteidl.plugins.spark_pb2.SparkJob """ - if self.spark_type == _spark_type.PYTHON: + if self.spark_type == SparkType.PYTHON: application_type = _spark_task.SparkApplication.PYTHON - elif self.spark_type == _spark_type.JAVA: + elif self.spark_type == SparkType.JAVA: application_type = _spark_task.SparkApplication.JAVA - elif self.spark_type == _spark_type.SCALA: + elif self.spark_type == SparkType.SCALA: application_type = _spark_task.SparkApplication.SCALA - elif self.spark_type == _spark_type.R: + elif self.spark_type == SparkType.R: application_type = _spark_task.SparkApplication.R else: raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") @@ -129,13 +136,13 @@ def from_flyte_idl(cls, pb2_object): :rtype: SparkJob """ - application_type = _spark_type.PYTHON + application_type = SparkType.PYTHON if pb2_object.type == _spark_task.SparkApplication.JAVA: - application_type = _spark_type.JAVA + application_type = SparkType.JAVA elif pb2_object.type == _spark_task.SparkApplication.SCALA: - application_type = _spark_type.SCALA + application_type = SparkType.SCALA elif pb2_object.type == _spark_task.SparkApplication.R: - application_type = _spark_type.R + application_type = SparkType.R return cls( type=application_type, diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py 
b/plugins/flytekit-spark/flytekitplugins/spark/task.py index bb20da8b48..37ae03e913 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -7,11 +7,10 @@ from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask -from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, SerializationSettings, TaskPlugins -from flytekit.sdk.spark_types import SparkType -from .models import SparkJob +from .models import SparkJob, SparkType @dataclass diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 4fc80d541f..dfda716a11 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -3,7 +3,7 @@ import flytekit from flytekit import task -from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import Image, ImageConfig, SerializationSettings diff --git a/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py b/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py index 404ef2fb23..93104eabdd 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py +++ b/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py @@ -1,8 +1,8 @@ from collections import OrderedDict -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig +from flytekit.tools.translator import get_serializable from .test_task import tk as not_tk diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 2c853f7811..167c8e796d 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -11,8 +11,8 @@ from flytekitplugins.sqlalchemy.task import SQLAlchemyTaskExecutor from flytekit import kwtypes, task, workflow +from flytekit.core.context_manager import SecretsManager from flytekit.models.security import Secret -from flytekit.testing import SecretsManager from flytekit.types.schema import FlyteSchema tk = SQLAlchemyTask( diff --git a/setup.py b/setup.py index bb5f9d1f1c..ae028038bb 100644 --- a/setup.py +++ b/setup.py @@ -2,45 +2,16 @@ from setuptools import find_packages, setup # noqa -# from flytekit.tools.lazy_loader import LazyLoadPlugin # noqa -# extras_require = LazyLoadPlugin.get_extras_require() - MIN_PYTHON_VERSION = (3, 7) CURRENT_PYTHON = sys.version_info[:2] -if CURRENT_PYTHON == (3, 6): - print( - f"Flytekit native typed API is supported for python versions {MIN_PYTHON_VERSION}+, Python 3.6 is supported" - f" only for legacy Flytekit API. This will be deprecated when Python 3.6 reaches end of life (Dec 23rd, 2021)," - f" we recommend migrating to the new API" - ) -elif CURRENT_PYTHON < MIN_PYTHON_VERSION: +if CURRENT_PYTHON < MIN_PYTHON_VERSION: print( f"Flytekit API is only supported for Python version is {MIN_PYTHON_VERSION}+. Detected you are on" f" version {CURRENT_PYTHON}, installation will not proceed!" 
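Two import moves recur across the plugin diffs: `ExecutionParameters` now comes from `flytekit.core.context_manager` (its old home, `flytekit.common.tasks.sdk_runnable`, is gone), and the Spark plugin owns its `SparkType` enum instead of importing it from `flytekit.sdk.spark_types`. A sketch of the new import sites, with a hypothetical task body:

.. code-block:: python

    import flytekit
    from flytekit import task
    from flytekit.core.context_manager import ExecutionParameters
    from flytekitplugins.spark.models import SparkType

    @task
    def when_am_i() -> str:
        # current_context() hands back the ExecutionParameters for this run
        params: ExecutionParameters = flytekit.current_context()
        return params.execution_date.isoformat()

    assert SparkType.SCALA.value == 2  # per the enum definition above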
) sys.exit(-1) -spark = ["pyspark>=2.4.0,<3.0.0"] -spark3 = ["pyspark>=3.0.0"] -sidecar = ["k8s-proto>=0.0.3,<1.0.0"] -schema = ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=4.0.0"] -hive_sensor = ["hmsclient>=0.0.1,<1.0.0"] -notebook = ["papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0,<6.0.0"] -sagemaker = ["sagemaker-training>=3.6.2,<4.0.0"] - -all_but_spark = sidecar + schema + hive_sensor + notebook + sagemaker - -extras_require = { - "spark": spark, - "spark3": spark3, - "sidecar": sidecar, - "schema": schema, - "hive_sensor": hive_sensor, - "notebook": notebook, - "sagemaker": sagemaker, - "all-spark2.4": spark + all_but_spark, - "all": spark3 + all_but_spark, -} +extras_require = {} __version__ = "0.0.0+develop" @@ -71,7 +42,7 @@ "click>=6.6,<8.0", "croniter>=0.3.20,<4.0.0", "deprecated>=1.0,<2.0", - "python-dateutil<=2.8.1,>=2.1", + "python-dateutil>=2.1", "grpcio>=1.3.0,<2.0", "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", @@ -80,7 +51,6 @@ "keyring>=18.0.1", "requests>=2.18.4,<3.0.0", "responses>=0.10.7", - "six>=1.9.0,<2.0.0", "sortedcontainers>=1.5.9<3.0.0", "statsd>=3.0.0,<4.0.0", "urllib3>=1.22,<2.0.0", @@ -104,7 +74,7 @@ "flytekit/bin/entrypoint.py", ], license="apache2", - python_requires=">=3.6", + python_requires=">=3.7", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 160ff2a53f..27bd3ea1a7 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -1,10 +1,6 @@ from datetime import timedelta from itertools import product -from six.moves import range - -from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.common.types.impl import schema as _schema_impl from flytekit.models import interface, literals, security, task, types from flytekit.models.core import identifier from flytekit.models.core import types as _core_types @@ -174,74 +170,6 @@ timedelta(seconds=5), ), (literals.Scalar(none_type=literals.Void()), None), - ( - literals.Scalar( - blob=literals.Blob( - literals.BlobMetadata(_core_types.BlobType("csv", _core_types.BlobType.BlobDimensionality.SINGLE)), - "s3://some/where", - ) - ), - _blob_impl.Blob("s3://some/where", format="csv"), - ), - ( - literals.Scalar( - blob=literals.Blob( - literals.BlobMetadata(_core_types.BlobType("", _core_types.BlobType.BlobDimensionality.SINGLE)), - "s3://some/where", - ) - ), - _blob_impl.Blob("s3://some/where"), - ), - ( - literals.Scalar( - blob=literals.Blob( - literals.BlobMetadata(_core_types.BlobType("csv", _core_types.BlobType.BlobDimensionality.MULTIPART)), - "s3://some/where/", - ) - ), - _blob_impl.MultiPartBlob("s3://some/where/", format="csv"), - ), - ( - literals.Scalar( - blob=literals.Blob( - literals.BlobMetadata(_core_types.BlobType("", _core_types.BlobType.BlobDimensionality.MULTIPART)), - "s3://some/where/", - ) - ), - _blob_impl.MultiPartBlob("s3://some/where/"), - ), - ( - literals.Scalar( - schema=literals.Schema( - "s3://some/where/", - types.SchemaType( - [ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", 
types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ] - ), - ) - ), - _schema_impl.Schema( - "s3://some/where/", - _schema_impl.SchemaType.promote_from_model( - types.SchemaType( - [ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ] - ) - ), - ), - ), ] LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE = [ diff --git a/tests/flytekit/common/task_definitions.py b/tests/flytekit/common/task_definitions.py deleted file mode 100644 index 56534fd7ea..0000000000 --- a/tests/flytekit/common/task_definitions.py +++ /dev/null @@ -1,9 +0,0 @@ -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task -def add_one(wf_params, a, b): - b.set(a + 1) diff --git a/tests/flytekit/common/workflows/__init__.py b/tests/flytekit/common/workflows/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/common/workflows/batch.py b/tests/flytekit/common/workflows/batch.py deleted file mode 100644 index d3e095dd82..0000000000 --- a/tests/flytekit/common/workflows/batch.py +++ /dev/null @@ -1,128 +0,0 @@ -from six import moves as _six_moves - -from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - - -@outputs(out_ints=[Types.Integer]) -@dynamic_task -def sample_batch_task_sq(wf_params, out_ints): - wf_params.stats.incr("task_run") - res2 = [] - for i in _six_moves.range(0, 3): - task = sq_sub_task(in1=i) - yield task - res2.append(task.outputs.out1) - out_ints.set(res2) - - -@outputs(out_str=[Types.String], out_ints=[[Types.Integer]]) -@dynamic_task -def no_inputs_sample_batch_task(wf_params, out_str, out_ints): - wf_params.stats.incr("task_run") - res = ["I'm the first result"] - for i in _six_moves.range(0, 3): - task = sub_task(in1=i) - yield task - res.append(task.outputs.out1) - res.append("I'm after each sub-task result") - res.append("I'm the last result") - - res2 = [] - for i in _six_moves.range(0, 3): - task = int_sub_task(in1=i) - yield task - res2.append(task.outputs.out1) - - # Nested batch tasks - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - out_str.set(res) - out_ints.set(res2) - - -@inputs(in1=Types.Integer) -@outputs(out_str=[Types.String]) -@dynamic_task(cache=True, cache_version="1") -def sample_batch_task_beatles_cached(wf_params, in1, out_str): - wf_params.stats.incr("task_run") - res2 = [] - for i in _six_moves.range(0, in1): - task = sample_beatles_lyrics_cached(in1=i) - yield task - res2.append(task.outputs.out1) - out_str.set(res2) - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@python_task(cache=True, cache_version="1") -def sample_beatles_lyrics_cached(wf_params, in1, out1): - 
wf_params.stats.incr("task_run") - lyrics = ["Ob-La-Di, Ob-La-Da", "When I'm 64", "Yesterday"] - out1.set(lyrics[in1 % 3]) - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@python_task -def sub_task(wf_params, in1, out1): - wf_params.stats.incr("task_run") - out1.set("hello {}".format(in1)) - - -@inputs(in1=Types.Integer) -@outputs(out1=[Types.Integer]) -@python_task -def int_sub_task(wf_params, in1, out1): - wf_params.stats.incr("task_run") - out1.set([in1, in1 * 2, in1 * 3]) - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.Integer) -@python_task -def sq_sub_task(wf_params, in1, out1): - wf_params.stats.incr("task_run") - out1.set(in1 * in1) - - -@inputs(ints_to_print=[[Types.Integer]], strings_to_print=[Types.String]) -@python_task(cache_version="1") -def print_every_time(wf_params, ints_to_print, strings_to_print): - wf_params.stats.incr("task_run") - print("Expected Int values: {}".format([[0, 0, 0], [1, 2, 3], [2, 4, 6], [0, 1, 4], [0, 1, 4]])) - print("Actual Int values: {}".format(ints_to_print)) - - print( - "Expected String values: {}".format( - [ - u"I'm the first result", - u"hello 0", - u"I'm after each sub-task result", - u"hello 1", - u"I'm after each sub-task result", - u"hello 2", - u"I'm after each sub-task result", - u"I'm the last result", - ] - ) - ) - print("Actual String values: {}".format(strings_to_print)) - - -@workflow_class -class BatchTasksWorkflow(object): - num_subtasks = Input(Types.Integer, default=3) - task1 = no_inputs_sample_batch_task() - task2 = sample_batch_task_beatles_cached(in1=num_subtasks) - t = print_every_time(ints_to_print=task1.outputs.out_ints, strings_to_print=task1.outputs.out_str) - ints_out = Output(task1.outputs.out_ints, sdk_type=[[Types.Integer]]) - str_out = Output(task2.outputs.out_str, sdk_type=[Types.String]) diff --git a/tests/flytekit/common/workflows/dynamic_workflows.py b/tests/flytekit/common/workflows/dynamic_workflows.py deleted file mode 100644 index fabab6367b..0000000000 --- a/tests/flytekit/common/workflows/dynamic_workflows.py +++ /dev/null @@ -1,39 +0,0 @@ -from flytekit.sdk import tasks as _tasks -from flytekit.sdk import workflow as _workflow -from flytekit.sdk.types import Types as _Types -from flytekit.sdk.workflow import Input, Output, workflow_class - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.outputs(out=_Types.Integer) -@_tasks.python_task -def inner_task(wf_params, num, out): - wf_params.logging.info("Running inner task... setting output to input") - out.set(num) - - -@_workflow.workflow_class() -class IdentityWorkflow(object): - a = _workflow.Input(_Types.Integer, default=5, help="Input for inner workflow") - odd_nums_task = inner_task(num=a) - task_output = _workflow.Output(odd_nums_task.outputs.out, sdk_type=_Types.Integer) - - -id_lp = IdentityWorkflow.create_launch_plan() - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.outputs(out=_Types.Integer) -@_tasks.dynamic_task -def lp_yield_task(wf_params, num, out): - wf_params.logging.info("Running inner task... 
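The `@dynamic_task` fan-out patterns being deleted here correspond to the `@dynamic` decorator that the pod-plugin tests above already import (`from flytekit import ... dynamic, map_task, task`). A minimal sketch of the squaring example in the new style:

.. code-block:: python

    import typing

    from flytekit import dynamic, task, workflow

    @task
    def sq_sub_task(in1: int) -> int:
        return in1 * in1

    @dynamic
    def sample_batch_task_sq(n: int) -> typing.List[int]:
        # each call compiles to a sub-node, like the yielded tasks above
        return [sq_sub_task(in1=i) for i in range(n)]

    @workflow
    def batch_wf(n: int = 3) -> typing.List[int]:
        return sample_batch_task_sq(n=n)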
yielding a launchplan") - identity_lp_execution = id_lp(a=num) - yield identity_lp_execution - out.set(identity_lp_execution.outputs.task_output) - - -@workflow_class -class DynamicLaunchPlanCaller(object): - outer_a = Input(_Types.Integer, default=5, help="Input for inner workflow") - lp_task = lp_yield_task(num=outer_a) - wf_output = Output(lp_task.outputs.out, sdk_type=_Types.Integer) diff --git a/tests/flytekit/common/workflows/failing_workflows.py b/tests/flytekit/common/workflows/failing_workflows.py deleted file mode 100644 index 260e460b17..0000000000 --- a/tests/flytekit/common/workflows/failing_workflows.py +++ /dev/null @@ -1,28 +0,0 @@ -from flytekit.models.core.workflow import WorkflowMetadata -from flytekit.sdk.tasks import python_task -from flytekit.sdk.workflow import workflow_class - - -@python_task -def div_zero(wf_params): - return 5 / 0 - - -@python_task -def log_something(wf_params): - wf_params.logging.warn("Hello world") - - -@workflow_class(on_failure=WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) -class FailingWorkflowWithRunToCompletion(object): - """ - [start] -> [first_layer] -> [second_layer] -> [end] - \\_ [first_layer_2] _/ - """ - - first_layer = log_something() - first_layer_2 = div_zero() - second_layer = div_zero() - - # This forces second_layer node to run after first layer - first_layer >> second_layer diff --git a/tests/flytekit/common/workflows/gpu.py b/tests/flytekit/common/workflows/gpu.py deleted file mode 100644 index 7e7621ac69..0000000000 --- a/tests/flytekit/common/workflows/gpu.py +++ /dev/null @@ -1,20 +0,0 @@ -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task(gpu_request="1", gpu_limit="1") -def add_one(wf_params, a, b): - # TODO lets add a test that works with tensorflow, but we need it to be in - # a different container - b.set(a + 1) - - -@workflow_class -class SimpleWorkflow(object): - input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help="Not required.") - a = add_one(a=input_1) - output = Output(a.outputs.b, sdk_type=Types.Integer) diff --git a/tests/flytekit/common/workflows/hive.py b/tests/flytekit/common/workflows/hive.py deleted file mode 100644 index b3cb56fe5d..0000000000 --- a/tests/flytekit/common/workflows/hive.py +++ /dev/null @@ -1,33 +0,0 @@ -import six as _six - -from flytekit.sdk.tasks import inputs, outputs, python_task, qubole_hive_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class - - -@outputs(hive_results=[Types.Schema()]) -@qubole_hive_task(tags=[_six.text_type("these"), _six.text_type("are"), _six.text_type("tags")]) -def generate_queries(wf_params, hive_results): - q1 = "SELECT 1" - q2 = "SELECT 'two'" - schema_1, formatted_query_1 = Types.Schema().create_from_hive_query(select_query=q1) - schema_2, formatted_query_2 = Types.Schema().create_from_hive_query(select_query=q2) - - hive_results.set([schema_1, schema_2]) - return [formatted_query_1, formatted_query_2] - - -@inputs(ss=[Types.Schema()]) -@python_task -def print_schemas(wf_params, ss): - for s in ss: - with s as r: - for df in r.iter_chunks(): - df = r.read() - print(df) - - -@workflow_class -class ExampleQueryWorkflow(object): - a = generate_queries() - b = print_schemas(ss=a.outputs.hive_results) diff --git a/tests/flytekit/common/workflows/nested.py 
b/tests/flytekit/common/workflows/nested.py deleted file mode 100644 index a364356ff6..0000000000 --- a/tests/flytekit/common/workflows/nested.py +++ /dev/null @@ -1,49 +0,0 @@ -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task -def add_one(wf_params, a, b): - b.set(a + 1) - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task -def subtract_one(wf_params, a, b): - b.set(a - 1) - - -@inputs(a=Types.Integer, b=Types.Integer) -@outputs(c=Types.Integer) -@python_task -def sum(wf_params, a, b, c): - c.set(a + b) - - -@workflow_class -class Child(object): - input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help="Not required.") - a = add_one(a=input_1) - b = add_one(a=input_2) - c = add_one(a=100) - output = Output(c.outputs.b, sdk_type=Types.Integer) - - -# Create a simple launch plan without any overrides for inputs to the child workflow. -child_lp = Child.create_launch_plan() - - -# Create a parent workflow that invokes the child workflow previously declared. Note that it takes advantage of -# default and required inputs on different calls. -@workflow_class -class Parent(object): - input_1 = Input(Types.Integer) - child1 = child_lp(input_1=input_1) - child2 = child_lp(input_1=input_1, input_2=10) - final_sum = sum(a=child1.outputs.output, b=child2.outputs.output) - output = Output(final_sum.outputs.c, sdk_type=Types.Integer) diff --git a/tests/flytekit/common/workflows/notebook.py b/tests/flytekit/common/workflows/notebook.py deleted file mode 100644 index e8af3b7467..0000000000 --- a/tests/flytekit/common/workflows/notebook.py +++ /dev/null @@ -1,25 +0,0 @@ -from flytekit.contrib.notebook.tasks import python_notebook, spark_notebook -from flytekit.sdk.tasks import inputs, outputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - -interactive_python = python_notebook( - notebook_path="../../../../notebook-task-examples/python-notebook.ipynb", - inputs=inputs(pi=Types.Float), - outputs=outputs(out=Types.Float), - cpu_request="1", - memory_request="1G", -) - -interactive_spark = spark_notebook( - notebook_path="../../../../notebook-task-examples/spark-notebook-pi.ipynb", - inputs=inputs(partitions=Types.Integer), - outputs=outputs(pi=Types.Float), -) - - -@workflow_class -class FlyteNotebookSparkWorkflow(object): - partitions = Input(Types.Integer, default=10) - out1 = interactive_spark(partitions=partitions) - out2 = interactive_python(pi=out1.outputs.pi) diff --git a/tests/flytekit/common/workflows/notifications.py b/tests/flytekit/common/workflows/notifications.py deleted file mode 100644 index d09bc14fad..0000000000 --- a/tests/flytekit/common/workflows/notifications.py +++ /dev/null @@ -1,34 +0,0 @@ -from flytekit.common import notifications as _notifications -from flytekit.models.core import execution as _execution -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -@inputs(a=Types.Integer, b=Types.Integer) -@outputs(c=Types.Integer) -@python_task -def add_two_integers(wf_params, a, b, c): - c.set(a + b) - - -@workflow_class -class BasicWorkflow(object): - input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=1, help="Not required.") - a = add_two_integers(a=input_1, b=input_2) - - 
-notification_lp = BasicWorkflow.create_launch_plan( - notifications=[ - _notifications.Email( - [ - _execution.WorkflowExecutionPhase.SUCCEEDED, - _execution.WorkflowExecutionPhase.FAILED, - _execution.WorkflowExecutionPhase.TIMED_OUT, - _execution.WorkflowExecutionPhase.ABORTED, - ], - ["flyte-test-notifications@mydomain.com"], - ) - ] -) diff --git a/tests/flytekit/common/workflows/presto.py b/tests/flytekit/common/workflows/presto.py deleted file mode 100644 index e681a3e953..0000000000 --- a/tests/flytekit/common/workflows/presto.py +++ /dev/null @@ -1,25 +0,0 @@ -from flytekit.common.tasks.presto_task import SdkPrestoTask -from flytekit.sdk.tasks import inputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - -schema = Types.Schema([("a", Types.String), ("b", Types.Integer)]) - -presto_task = SdkPrestoTask( - task_inputs=inputs(ds=Types.String, rg=Types.String), - statement="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds}}' LIMIT 10", - output_schema=schema, - routing_group="{{ .Inputs.rg }}", - # catalog="hive", - # schema="city", -) - - -@workflow_class() -class PrestoWorkflow(object): - ds = Input(Types.String, required=True, help="Test string with no default") - # routing_group = Input(Types.String, required=True, help="Test string with no default") - - p_task = presto_task(ds=ds, rg="etl") - - output_a = Output(p_task.outputs.results, sdk_type=schema) diff --git a/tests/flytekit/common/workflows/python.py b/tests/flytekit/common/workflows/python.py deleted file mode 100644 index a0d423b86c..0000000000 --- a/tests/flytekit/common/workflows/python.py +++ /dev/null @@ -1,67 +0,0 @@ -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -@inputs(value_to_print=Types.Integer) -@outputs(out=Types.Integer) -@python_task(cache_version="1") -def add_one_and_print(workflow_parameters, value_to_print, out): - workflow_parameters.stats.incr("task_run") - added = value_to_print + 1 - print("My printed value: {}".format(added)) - out.set(added) - - -@inputs(value1_to_print=Types.Integer, value2_to_print=Types.Integer) -@outputs(out=Types.Integer) -@python_task(cache_version="1") -def sum_non_none(workflow_parameters, value1_to_print, value2_to_print, out): - workflow_parameters.stats.incr("task_run") - added = 0 - for value in [value1_to_print, value2_to_print]: - print("Adding values: {}".format(value)) - if value is not None: - added += value - added += 1 - print("My printed value: {}".format(added)) - out.set(added) - - -@inputs( - value1_to_add=Types.Integer, - value2_to_add=Types.Integer, - value3_to_add=Types.Integer, - value4_to_add=Types.Integer, -) -@outputs(out=Types.Integer) -@python_task(cache_version="1") -def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, value3_to_add, value4_to_add, out): - workflow_parameters.stats.incr("task_run") - summed = sum([value1_to_add, value2_to_add, value3_to_add, value4_to_add]) - print("Summed up to: {}".format(summed)) - out.set(summed) - - -@inputs(value_to_print=Types.Integer, date_triggered=Types.Datetime) -@python_task(cache_version="1") -def print_every_time(workflow_parameters, value_to_print, date_triggered): - workflow_parameters.stats.incr("task_run") - print("My printed value: {} @ {}".format(value_to_print, date_triggered)) - - -@workflow_class -class PythonTasksWorkflow(object): - triggered_date = Input(Types.Datetime) - print1a = 
add_one_and_print(value_to_print=3) - print1b = add_one_and_print(value_to_print=101) - print2 = sum_non_none(value1_to_print=print1a.outputs.out, value2_to_print=print1b.outputs.out) - print3 = add_one_and_print(value_to_print=print2.outputs.out) - print4 = add_one_and_print(value_to_print=print3.outputs.out) - print_sum = sum_and_print( - value1_to_add=print2.outputs.out, - value2_to_add=print3.outputs.out, - value3_to_add=print4.outputs.out, - value4_to_add=100, - ) - print_always = print_every_time(value_to_print=print_sum.outputs.out, date_triggered=triggered_date) diff --git a/tests/flytekit/common/workflows/raw_container.py b/tests/flytekit/common/workflows/raw_container.py deleted file mode 100644 index 60edbcc737..0000000000 --- a/tests/flytekit/common/workflows/raw_container.py +++ /dev/null @@ -1,31 +0,0 @@ -from flytekit.common.tasks.raw_container import SdkRawContainerTask -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - -square = SdkRawContainerTask( - input_data_dir="/var/inputs", - output_data_dir="/var/outputs", - inputs={"val": Types.Integer}, - outputs={"out": Types.Integer}, - image="alpine", - command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"], -) - -sum = SdkRawContainerTask( - input_data_dir="/var/flyte/inputs", - output_data_dir="/var/flyte/outputs", - inputs={"x": Types.Integer, "y": Types.Integer}, - outputs={"out": Types.Integer}, - image="alpine", - command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"], -) - - -@workflow_class -class RawContainerWorkflow(object): - val1 = Input(Types.Integer) - val2 = Input(Types.Integer) - sq1 = square(val=val1) - sq2 = square(val=val2) - sm = sum(x=sq1.outputs.out, y=sq2.outputs.out) - sum_of_squares = Output(sm.outputs.out, sdk_type=Types.Integer) diff --git a/tests/flytekit/common/workflows/raw_edge_detector.py b/tests/flytekit/common/workflows/raw_edge_detector.py deleted file mode 100644 index eeefac4de1..0000000000 --- a/tests/flytekit/common/workflows/raw_edge_detector.py +++ /dev/null @@ -1,20 +0,0 @@ -from flytekit.common.tasks.raw_container import SdkRawContainerTask -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow_class - -edges = SdkRawContainerTask( - input_data_dir="/inputs", - output_data_dir="/outputs", - inputs={"image": Types.Blob, "script": Types.Blob}, - outputs={"edges": Types.Blob}, - image="jjanzic/docker-python3-opencv", - command=["python", "{{.inputs.script}}", "/inputs/image", "/outputs/edges"], -) - - -@workflow_class -class EdgeDetector(object): - script = Input(Types.Blob) - image = Input(Types.Blob) - edge_task = edges(script=script, image=image) - out = Output(edge_task.outputs.edges, sdk_type=Types.Blob) diff --git a/tests/flytekit/common/workflows/scala_spark.py b/tests/flytekit/common/workflows/scala_spark.py deleted file mode 100644 index 7df3b0fb32..0000000000 --- a/tests/flytekit/common/workflows/scala_spark.py +++ /dev/null @@ -1,32 +0,0 @@ -from flytekit.sdk.spark_types import SparkType -from flytekit.sdk.tasks import generic_spark_task, inputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - -scala_spark = generic_spark_task( - spark_type=SparkType.SCALA, - inputs=inputs(partitions=Types.Integer), - main_class="org.apache.spark.examples.SparkPi", - main_application_file="local:///opt/spark/examples/jars/spark-examples.jar", - spark_conf={ - 
"spark.driver.memory": "1000M", - "spark.executor.memory": "1000M", - "spark.executor.cores": "1", - "spark.executor.instances": "2", - }, - cache_version="1", -) - - -@inputs(date_triggered=Types.Datetime) -@python_task(cache_version="1") -def print_every_time(workflow_parameters, date_triggered): - print("My input : {}".format(date_triggered)) - - -@workflow_class -class SparkTasksWorkflow(object): - triggered_date = Input(Types.Datetime) - partitions = Input(Types.Integer) - spark_task = scala_spark(partitions=partitions) - print_always = print_every_time(date_triggered=triggered_date) diff --git a/tests/flytekit/common/workflows/sidecar.py b/tests/flytekit/common/workflows/sidecar.py deleted file mode 100644 index ff0d4e7484..0000000000 --- a/tests/flytekit/common/workflows/sidecar.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -import time - -from k8s.io.api.core.v1 import generated_pb2 - -from flytekit.sdk.tasks import sidecar_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -def generate_pod_spec_for_task(): - pod_spec = generated_pb2.PodSpec() - secondary_container = generated_pb2.Container( - name="secondary", - image="alpine", - ) - secondary_container.command.extend(["/bin/sh"]) - secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"]) - shared_volume_mount = generated_pb2.VolumeMount( - name="shared-data", - mountPath="/data", - ) - secondary_container.volumeMounts.extend([shared_volume_mount]) - - primary_container = generated_pb2.Container(name="primary") - primary_container.volumeMounts.extend([shared_volume_mount]) - - pod_spec.volumes.extend( - [ - generated_pb2.Volume( - name="shared-data", - volumeSource=generated_pb2.VolumeSource( - emptyDir=generated_pb2.EmptyDirVolumeSource( - medium="Memory", - ) - ), - ) - ] - ) - pod_spec.containers.extend([primary_container, secondary_container]) - return pod_spec - - -@sidecar_task( - pod_spec=generate_pod_spec_for_task(), - primary_container_name="primary", -) -def a_sidecar_task(wfparams): - while not os.path.isfile("/data/message.txt"): - time.sleep(5) - - -@workflow_class -class SimpleSidecarWorkflow(object): - input_1 = Input(Types.String) - my_sidecar_task = a_sidecar_task() diff --git a/tests/flytekit/common/workflows/simple.py b/tests/flytekit/common/workflows/simple.py deleted file mode 100644 index f264fe39bf..0000000000 --- a/tests/flytekit/common/workflows/simple.py +++ /dev/null @@ -1,114 +0,0 @@ -import pandas as _pd - -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task -def add_one(wf_params, a, b): - b.set(a + 1) - - -@inputs(a=Types.Integer) -@outputs(b=Types.Integer) -@python_task(cache=True, cache_version="1") -def subtract_one(wf_params, a, b): - b.set(a - 1) - - -@outputs( - a=Types.Blob, - b=Types.CSV, - c=Types.MultiPartCSV, - d=Types.MultiPartBlob, - e=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]), -) -@python_task -def write_special_types(wf_params, a, b, c, d, e): - blob = Types.Blob() - with blob as w: - w.write("hello I'm a blob".encode("utf-8")) - - csv = Types.CSV() - with csv as w: - w.write("hello,i,iz,blob") - - mpcsv = Types.MultiPartCSV() - with mpcsv.create_part("000000") as w: - w.write("hello,i,iz,blob") - with mpcsv.create_part("000001") as w: - w.write("hello,i,iz,blob2") - - mpblob = Types.MultiPartBlob() - with 
mpblob.create_part("000000") as w: - w.write("hello I'm a mp blob".encode("utf-8")) - with mpblob.create_part("000001") as w: - w.write("hello I'm a mp blob too".encode("utf-8")) - - schema = Types.Schema([("a", Types.Integer), ("b", Types.Integer)])() - with schema as w: - w.write(_pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})) - w.write(_pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6, 5, 4]})) - - a.set(blob) - b.set(csv) - c.set(mpcsv) - d.set(mpblob) - e.set(schema) - - -@inputs( - a=Types.Blob, - b=Types.CSV, - c=Types.MultiPartCSV, - d=Types.MultiPartBlob, - e=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]), -) -@python_task -def read_special_types(wf_params, a, b, c, d, e): - with a as r: - assert r.read().decode("utf-8") == "hello I'm a blob" - - with b as r: - assert r.read() == "hello,i,iz,blob" - - with c as r: - assert len(r) == 2 - assert r[0].read() == "hello,i,iz,blob" - assert r[1].read() == "hello,i,iz,blob2" - - with d as r: - assert len(r) == 2 - assert r[0].read().decode("utf-8") == "hello I'm a mp blob" - assert r[1].read().decode("utf-8") == "hello I'm a mp blob too" - - with e as r: - df = r.read() - assert df["a"].tolist() == [1, 2, 3] - assert df["b"].tolist() == [4, 5, 6] - - df = r.read() - assert df["a"].tolist() == [3, 2, 1] - assert df["b"].tolist() == [6, 5, 4] - assert r.read() is None - - -@workflow_class -class SimpleWorkflow(object): - input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help="Not required.") - a = add_one(a=input_1) - b = add_one(a=input_2) - c = subtract_one(a=input_1) - - d = write_special_types() - e = read_special_types( - a=d.outputs.a, - b=d.outputs.b, - c=d.outputs.c, - d=d.outputs.d, - e=d.outputs.e, - ) diff --git a/tests/flytekit/common/workflows/spark.py b/tests/flytekit/common/workflows/spark.py deleted file mode 100644 index b3f381f6ba..0000000000 --- a/tests/flytekit/common/workflows/spark.py +++ /dev/null @@ -1,50 +0,0 @@ -import random -from operator import add - -from six.moves import range - -from flytekit.sdk.tasks import inputs, outputs, python_task, spark_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -@inputs(partitions=Types.Integer) -@outputs(out=Types.Float) -@spark_task( - spark_conf={ - "spark.driver.memory": "1000M", - "spark.executor.memory": "1000M", - "spark.executor.cores": "1", - "spark.executor.instances": "2", - "spark.hadoop.mapred.output.committer.class": "org.apache.hadoop.mapred.DirectFileOutputCommitter", - "spark.hadoop.mapreduce.use.directfileoutputcommitter": "true", - }, - cache_version="1", -) -def hello_spark(workflow_parameters, spark_context, partitions, out): - print("Starting Spark with Partitions: {}".format(partitions)) - - n = 100000 * partitions - count = spark_context.parallelize(range(1, n + 1), partitions).map(f).reduce(add) - pi_val = 4.0 * count / n - print("Pi val is :{}".format(pi_val)) - out.set(pi_val) - - -@inputs(value_to_print=Types.Float, date_triggered=Types.Datetime) -@python_task(cache_version="1") -def print_every_time(workflow_parameters, value_to_print, date_triggered): - print("My printed value: {} @ {}".format(value_to_print, date_triggered)) - - -def f(_): - x = random.random() * 2 - 1 - y = random.random() * 2 - 1 - return 1 if x ** 2 + y ** 2 <= 1 else 0 - - -@workflow_class -class SparkTasksWorkflow(object): - triggered_date = Input(Types.Datetime) - sparkTask = hello_spark(partitions=50) - print_always = print_every_time(value_to_print=sparkTask.outputs.out, 
date_triggered=triggered_date) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index a45c037279..7784d76d82 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -9,8 +9,8 @@ import pytest from flytekit import kwtypes -from flytekit.common.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.core.launch_plan import LaunchPlan +from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.remote.remote import FlyteRemote from flytekit.types.schema import FlyteSchema diff --git a/tests/flytekit/loadtests/__init__.py b/tests/flytekit/loadtests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/loadtests/cp_orchestrator.py b/tests/flytekit/loadtests/cp_orchestrator.py deleted file mode 100644 index 776afecd59..0000000000 --- a/tests/flytekit/loadtests/cp_orchestrator.py +++ /dev/null @@ -1,29 +0,0 @@ -from six.moves import range - -from flytekit.sdk.workflow import workflow_class -from tests.flytekit.loadtests.cp_python import FlyteCPPythonLoadTestWorkflow -from tests.flytekit.loadtests.cp_spark import FlyteCPSparkLoadTestWorkflow - -# launch plans for individual load tests. -python_loadtest_lp = FlyteCPPythonLoadTestWorkflow.create_launch_plan() -spark_loadtest_lp = FlyteCPSparkLoadTestWorkflow.create_launch_plan() - - -# Orchestrator workflow invokes the individual load test workflows (hive, python, spark). Its static for now but we -# will make it dynamic in future -@workflow_class -class CPLoadTestOrchestrationWorkflow(object): - - # python load tests. 5 tasks each. Total: 1 cpu 5gb memory per python workflow. - python_task_count = 50 - p = [None] * python_task_count - - for i in range(0, python_task_count): - p[i] = python_loadtest_lp() - - # spark load tests. - spark_task_count = 30 - s = [None] * spark_task_count - - for i in range(0, spark_task_count): - s[i] = spark_loadtest_lp() diff --git a/tests/flytekit/loadtests/cp_python.py b/tests/flytekit/loadtests/cp_python.py deleted file mode 100644 index f26cbb51f7..0000000000 --- a/tests/flytekit/loadtests/cp_python.py +++ /dev/null @@ -1,28 +0,0 @@ -import time - -from six.moves import range - -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class - - -@inputs(value1_to_add=Types.Integer, value2_to_add=Types.Integer) -@outputs(out=Types.Integer) -@python_task(cpu_request="1", cpu_limit="1", memory_request="3G") -def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, out): - for i in range(5 * 60): - print("This is load test task. 
I have been running for {} seconds.".format(i)) - time.sleep(1) - - summed = sum([value1_to_add, value2_to_add]) - print("Summed up to: {}".format(summed)) - out.set(summed) - - -@workflow_class -class FlyteCPPythonLoadTestWorkflow(object): - - print_sum = [None] * 5 - for i in range(0, 5): - print_sum[i] = sum_and_print(value1_to_add=1, value2_to_add=1) diff --git a/tests/flytekit/loadtests/cp_spark.py b/tests/flytekit/loadtests/cp_spark.py deleted file mode 100644 index b79fbc3509..0000000000 --- a/tests/flytekit/loadtests/cp_spark.py +++ /dev/null @@ -1,42 +0,0 @@ -import random -from operator import add - -from six.moves import range - -from flytekit.sdk.tasks import inputs, outputs, spark_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class - - -@inputs(partitions=Types.Integer) -@outputs(out=Types.Float) -@spark_task( - spark_conf={ - "spark.driver.memory": "600M", - "spark.executor.memory": "600M", - "spark.executor.cores": "1", - "spark.executor.instances": "1", - "spark.hadoop.mapred.output.committer.class": "org.apache.hadoop.mapred.DirectFileOutputCommitter", - "spark.hadoop.mapreduce.use.directfileoutputcommitter": "true", - }, - cache_version="1", -) -def hello_spark(workflow_parameters, spark_context, partitions, out): - print("Starting Spark with Partitions: {}".format(partitions)) - - n = 30000 * partitions - count = spark_context.parallelize(range(1, n + 1), partitions).map(f).reduce(add) - pi_val = 4.0 * count / n - print("Pi val is :{}".format(pi_val)) - out.set(pi_val) - - -def f(_): - x = random.random() * 2 - 1 - y = random.random() * 2 - 1 - return 1 if x ** 2 + y ** 2 <= 1 else 0 - - -@workflow_class -class FlyteCPSparkLoadTestWorkflow(object): - sparkTask = hello_spark(partitions=50) diff --git a/tests/flytekit/loadtests/dynamic_job.py b/tests/flytekit/loadtests/dynamic_job.py deleted file mode 100644 index 520ff401a7..0000000000 --- a/tests/flytekit/loadtests/dynamic_job.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -from six.moves import range - -from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - - -@inputs(value1=Types.Integer) -@outputs(out=Types.Integer) -@python_task(cpu_request="1", cpu_limit="1", memory_request="5G") -def dynamic_sub_task(workflow_parameters, value1, out): - for i in range(11 * 60): - print("This is load test task. 
I have been running for {} seconds.".format(i)) - time.sleep(1) - - output = value1 * 2 - print("Output: {}".format(output)) - out.set(output) - - -@inputs(tasks_count=Types.Integer) -@outputs(out=[Types.Integer]) -@dynamic_task(cache_version="1") -def dynamic_task(workflow_parameters, tasks_count, out): - res = [] - for i in range(0, tasks_count): - task = dynamic_sub_task(value1=i) - yield task - res.append(task.outputs.out) - - # Define how to set the final result of the task - out.set(res) - - -@workflow_class -class FlyteDJOLoadTestWorkflow(object): - tasks_count = Input(Types.Integer) - dj = dynamic_task(tasks_count=tasks_count) diff --git a/tests/flytekit/loadtests/orchestrator.py b/tests/flytekit/loadtests/orchestrator.py deleted file mode 100644 index 5981502319..0000000000 --- a/tests/flytekit/loadtests/orchestrator.py +++ /dev/null @@ -1,48 +0,0 @@ -from six.moves import range - -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class -from tests.flytekit.loadtests.dynamic_job import FlyteDJOLoadTestWorkflow -from tests.flytekit.loadtests.hive import FlyteHiveLoadTestWorkflow -from tests.flytekit.loadtests.python import FlytePythonLoadTestWorkflow -from tests.flytekit.loadtests.spark import FlyteSparkLoadTestWorkflow - -# launch plans for individual load tests. -python_loadtest_lp = FlytePythonLoadTestWorkflow.create_launch_plan() -hive_loadtest_lp = FlyteHiveLoadTestWorkflow.create_launch_plan() -spark_loadtest_lp = FlyteSparkLoadTestWorkflow.create_launch_plan() -dynamic_job_loadtest_lp = FlyteDJOLoadTestWorkflow.create_launch_plan() - - -# Orchestrator workflow invokes the individual load test workflows (hive, python, spark). Its static for now but we -# will make it dynamic in future,. -@workflow_class -class LoadTestOrchestrationWorkflow(object): - - # 30 python tasks ~= 75 i3.16x nodes on AWS Batch - python_task_count = 30 - # 30 spark tasks ~= 60 i3.16x nodes on AWS Batch - spark_task_count = 30 - # 3 dynamic-jobs each of 1000 tasks ~= 3*20 i3.16x nodes on AWS Batch - djo_task_count = 1000 - dj_count = 3 - - p = [None] * python_task_count - s = [None] * spark_task_count - d = [None] * dj_count - - # python tasks - for i in range(0, python_task_count): - p[i] = python_loadtest_lp() - - # dynamic-job tasks - for i in range(0, dj_count): - d[i] = dynamic_job_loadtest_lp(tasks_count=djo_task_count) - - # hive load tests. - # h1 = hive_loadtest_lp() - - # spark load tests - trigger_time = Input(Types.Datetime) - for i in range(0, spark_task_count): - s[i] = spark_loadtest_lp(triggered_date=trigger_time, offset=i) diff --git a/tests/flytekit/loadtests/python.py b/tests/flytekit/loadtests/python.py deleted file mode 100644 index e6da8df722..0000000000 --- a/tests/flytekit/loadtests/python.py +++ /dev/null @@ -1,27 +0,0 @@ -import time - -from six.moves import range - -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class - - -@inputs(value1_to_add=Types.Integer, value2_to_add=Types.Integer) -@outputs(out=Types.Integer) -@python_task(cpu_request="5", cpu_limit="5", memory_request="32G") -def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, out): - for i in range(11 * 60): - print("This is load test task. 
I have been running for {} seconds.".format(i)) - time.sleep(1) - - summed = sum([value1_to_add, value2_to_add]) - print("Summed up to: {}".format(summed)) - out.set(summed) - - -@workflow_class -class FlytePythonLoadTestWorkflow(object): - print_sum = [None] * 30 - for i in range(0, 30): - print_sum[i] = sum_and_print(value1_to_add=1, value2_to_add=1) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index d54d7dc52b..3a71b567d8 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,194 +1,30 @@ -import os import typing from collections import OrderedDict import mock import pytest -import six -from click.testing import CliRunner -from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core.errors_pb2 import ErrorDocument -from flytekit.bin.entrypoint import _dispatch_execute, _legacy_execute_task, execute_task_cmd, setup_execution -from flytekit.common import constants as _constants -from flytekit.common import utils as _utils -from flytekit.common.exceptions import user as user_exceptions -from flytekit.common.exceptions.scopes import system_entry_point -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration +from flytekit.bin.entrypoint import _dispatch_execute, setup_execution from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import user as user_exceptions +from flytekit.exceptions.scopes import system_entry_point from flytekit.extras.persistence.gcs_gsutil import GCSPersistence from flytekit.extras.persistence.s3_awscli import S3Persistence from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models -from tests.flytekit.common import task_definitions as _task_defs - - -def _type_map_from_variable_map(variable_map): - return {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in six.iteritems(variable_map)} - - -def test_single_step_entrypoint_in_proc(): - with _TemporaryConfiguration( - os.path.join(os.path.dirname(__file__), "fake.config"), - internal_overrides={"project": "test", "domain": "development"}, - ): - with _utils.AutoDeletingTempDir("in") as input_dir: - literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, - _type_map_from_variable_map(_task_defs.add_one.interface.inputs), - ) - input_file = os.path.join(input_dir.name, "inputs.pb") - _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) - - with _utils.AutoDeletingTempDir("out") as output_dir: - _legacy_execute_task( - _task_defs.add_one.task_module, - _task_defs.add_one.task_function_name, - input_file, - output_dir.name, - output_dir.name, - False, - ) - - p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), - ) - raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( - _literal_models.LiteralMap.from_flyte_idl(p), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs), - ) - assert raw_map["b"] == 10 - assert len(raw_map) == 1 - - -def test_single_step_entrypoint_out_of_proc(): - 
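
Migration note: the legacy entrypoint tests deleted in this hunk (above and continuing below) exercised _legacy_execute_task and the old --task-module/--task-name CLI flags, both of which are dropped in this change; only the _dispatch_execute path keeps coverage from here on. The proto I/O helpers those tests imported from flytekit.common.utils now live in flytekit.core.utils. A minimal sketch of the relocated calls, mirroring what the surviving tests do (file paths here are illustrative):

    from flyteidl.core import literals_pb2
    from flytekit.core import utils
    from flytekit.models import literals as literal_models

    # Load a serialized LiteralMap, as _dispatch_execute does with inputs.pb.
    proto = utils.load_proto_from_file(literals_pb2.LiteralMap, "/tmp/inputs.pb")
    idl_inputs = literal_models.LiteralMap.from_flyte_idl(proto)

    # Write a proto back out, as the output/error paths do.
    utils.write_proto_to_file(idl_inputs.to_flyte_idl(), "/tmp/outputs.pb")
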
with _TemporaryConfiguration( - os.path.join(os.path.dirname(__file__), "fake.config"), - internal_overrides={"project": "test", "domain": "development"}, - ): - with _utils.AutoDeletingTempDir("in") as input_dir: - literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, - _type_map_from_variable_map(_task_defs.add_one.interface.inputs), - ) - input_file = os.path.join(input_dir.name, "inputs.pb") - _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) - - with _utils.AutoDeletingTempDir("out") as output_dir: - cmd = [] - cmd.extend(["--task-module", _task_defs.add_one.task_module]) - cmd.extend(["--task-name", _task_defs.add_one.task_function_name]) - cmd.extend(["--inputs", input_file]) - cmd.extend(["--output-prefix", output_dir.name]) - result = CliRunner().invoke(execute_task_cmd, cmd) - - assert result.exit_code == 0 - p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), - ) - raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( - _literal_models.LiteralMap.from_flyte_idl(p), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs), - ) - assert raw_map["b"] == 10 - assert len(raw_map) == 1 - - -def test_arrayjob_entrypoint_in_proc(): - with _TemporaryConfiguration( - os.path.join(os.path.dirname(__file__), "fake.config"), - internal_overrides={"project": "test", "domain": "development"}, - ): - with _utils.AutoDeletingTempDir("dir") as dir: - literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, - _type_map_from_variable_map(_task_defs.add_one.interface.inputs), - ) - - input_dir = os.path.join(dir.name, "1") - os.mkdir(input_dir) # auto cleanup will take this subdir into account - - input_file = os.path.join(input_dir, "inputs.pb") - _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) - - # construct indexlookup.pb which has array: [1] - mapped_index = _literal_models.Literal( - _literal_models.Scalar(primitive=_literal_models.Primitive(integer=1)) - ) - index_lookup_collection = _literal_models.LiteralCollection([mapped_index]) - index_lookup_file = os.path.join(dir.name, "indexlookup.pb") - _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file) - - # fake arrayjob task by setting environment variables - orig_env_index_var_name = os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME") - orig_env_array_index = os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX") - os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = "AWS_BATCH_JOB_ARRAY_INDEX" - os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = "0" - - _legacy_execute_task( - _task_defs.add_one.task_module, - _task_defs.add_one.task_function_name, - dir.name, - dir.name, - dir.name, - False, - ) - - raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( - _literal_models.LiteralMap.from_flyte_idl( - _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), - ) - ), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs), - ) - assert raw_map["b"] == 10 - assert len(raw_map) == 1 - - # reset the env vars - if orig_env_index_var_name: - os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = orig_env_index_var_name - if orig_env_array_index: - os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = orig_env_array_index - - -@mock.patch("flytekit.bin.entrypoint._legacy_execute_task") -def test_backwards_compatible_replacement(mock_legacy_execute_task): - def return_args(*args, **kwargs): - assert args[4] is None - - 
mock_legacy_execute_task.side_effect = return_args - - with _TemporaryConfiguration( - os.path.join(os.path.dirname(__file__), "fake.config"), - internal_overrides={"project": "test", "domain": "development"}, - ): - with _utils.AutoDeletingTempDir("in"): - with _utils.AutoDeletingTempDir("out"): - cmd = [] - cmd.extend(["--task-module", "fake"]) - cmd.extend(["--task-name", "fake"]) - cmd.extend(["--inputs", "fake"]) - cmd.extend(["--output-prefix", "fake"]) - cmd.extend(["--raw-output-data-prefix", "{{.rawOutputDataPrefix}}"]) - result = CliRunner().invoke(execute_task_cmd, cmd) - assert result.exit_code == 0 - - -@mock.patch("flytekit.common.utils.load_proto_from_file") + + +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_void(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -214,10 +50,10 @@ def verify_output(*args, **kwargs): assert mock_write_to_file.call_count == 1 -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -243,10 +79,10 @@ def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_d assert mock_write_to_file.call_count == 0 -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_exception(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -273,7 +109,7 @@ def verify_output(*args, **kwargs): # This function collects outputs instead of writing them to a file. 
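
Migration note: the patch-target updates in the hunks below follow one mechanical rule: anything previously patched at flytekit.common.utils is now patched at flytekit.core.utils. A minimal sketch of the updated pattern (test name and assertions are illustrative):

    import mock

    @mock.patch("flytekit.core.utils.load_proto_from_file")
    @mock.patch("flytekit.core.utils.write_proto_to_file")
    def test_patched_proto_io(mock_write, mock_load):
        # Stacked patch decorators bind bottom-up, so write_proto_to_file is the
        # first argument -- the same ordering used throughout these tests.
        mock_load.return_value = None
        assert mock_write.call_count == 0
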
-# See flytekit.common.utils.write_proto_to_file for the original +# See flytekit.core.utils.write_proto_to_file for the original def get_output_collector(results: OrderedDict): def output_collector(proto, path): results[path] = proto @@ -281,10 +117,10 @@ def output_collector(proto, path): return output_collector -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_normal(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -318,10 +154,10 @@ def t1(a: int) -> str: assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5" -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -358,10 +194,10 @@ def t1(a: int) -> str: assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_user_error_recoverable(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True @@ -402,10 +238,10 @@ def my_subwf(a: int) -> typing.List[str]: assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER -@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.common.utils.write_proto_to_file") +@mock.patch("flytekit.core.utils.write_proto_to_file") def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens mock_get_data.return_value = True diff --git a/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py b/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py deleted file mode 100644 index e19fea82a0..0000000000 --- a/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from flytekit.clis.sdk_in_container import launch_plan -from flytekit.clis.sdk_in_container.launch_plan import launch_plans - - -def test_list_commands(mock_ctx): - g = launch_plan.LaunchPlanExecuteGroup("test_group") - v = g.list_commands(mock_ctx) - 
assert v == ["common.workflows.simple.SimpleWorkflow"] - - -def test_get_commands(mock_ctx): - g = launch_plan.LaunchPlanExecuteGroup("test_group") - v = g.get_command(mock_ctx, "common.workflows.simple.SimpleWorkflow") - assert v.params[0].human_readable_name == "input_1" - assert "INTEGER" in v.params[0].help - assert v.params[1].human_readable_name == "input_2" - assert "INTEGER" in v.params[1].help - assert "Not required." in v.params[1].help - - with pytest.raises(Exception): - g.get_command(mock_ctx, "common.workflows.simple.DoesNotExist") - with pytest.raises(Exception): - g.get_command(mock_ctx, "does.not.exist") - - -def test_launch_plans_commands(mock_ctx): - command_names = [c for c in launch_plans.list_commands(mock_ctx)] - assert command_names == sorted(["execute", "activate-all", "activate-all-schedules"]) diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index e14be7fc5b..ced8c716c5 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -7,8 +7,8 @@ import flytekit from flytekit.clis.sdk_in_container import package, pyflyte, serialize -from flytekit.common.exceptions.user import FlyteValidationException from flytekit.core import context_manager +from flytekit.exceptions.user import FlyteValidationException def test_validate_image(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py deleted file mode 100644 index d1a6eed495..0000000000 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ /dev/null @@ -1,40 +0,0 @@ -from mock import MagicMock - -from flytekit.common.launch_plan import SdkLaunchPlan -from flytekit.common.tasks.task import SdkTask -from flytekit.common.workflow import SdkWorkflow - - -def test_register_workflows(mock_clirunner, monkeypatch): - - mock_register_task = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkTask, "register", mock_register_task) - mock_register_workflow = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkWorkflow, "register", mock_register_workflow) - mock_register_launch_plan = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkLaunchPlan, "register", mock_register_launch_plan) - - result = mock_clirunner("register", "-p", "project", "-d", "development", "-v", "--version", "workflows") - - assert result.exit_code == 0 - - assert len(mock_register_task.mock_calls) == 4 - assert len(mock_register_workflow.mock_calls) == 1 - assert len(mock_register_launch_plan.mock_calls) == 1 - - -def test_register_workflows_with_test_switch(mock_clirunner, monkeypatch): - mock_register_task = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkTask, "register", mock_register_task) - mock_register_workflow = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkWorkflow, "register", mock_register_workflow) - mock_register_launch_plan = MagicMock(return_value=MagicMock()) - monkeypatch.setattr(SdkLaunchPlan, "register", mock_register_launch_plan) - - result = mock_clirunner("register", "-p", "project", "-d", "development", "-v", "--version", "--test", "workflows") - - assert result.exit_code == 0 - - assert len(mock_register_task.mock_calls) == 0 - assert len(mock_register_workflow.mock_calls) == 0 - assert len(mock_register_launch_plan.mock_calls) == 0 diff --git a/tests/flytekit/unit/cli/test_cli_helpers.py b/tests/flytekit/unit/cli/test_cli_helpers.py index 9188b2fc3a..3ed08848c3 100644 --- 
a/tests/flytekit/unit/cli/test_cli_helpers.py +++ b/tests/flytekit/unit/cli/test_cli_helpers.py @@ -2,15 +2,12 @@ import flyteidl.admin.task_pb2 as _task_pb2 import flyteidl.admin.workflow_pb2 as _workflow_pb2 import flyteidl.core.tasks_pb2 as _core_task_pb2 -import pytest from flyteidl.core import identifier_pb2 as _identifier_pb2 from flyteidl.core import workflow_pb2 as _core_workflow_pb2 from flyteidl.core.identifier_pb2 import LAUNCH_PLAN from flytekit.clis import helpers from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template_nodes, hydrate_registration_parameters -from flytekit.models import literals, types -from flytekit.models.interface import Parameter, ParameterMap, Variable def test_parse_args_into_dict(): @@ -28,36 +25,6 @@ def test_parse_args_into_dict(): assert output == {} -def test_construct_literal_map_from_variable_map(): - v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") - variable_map = { - "inputa": v, - } - - input_txt_dictionary = {"inputa": "15"} - - literal_map = helpers.construct_literal_map_from_variable_map(variable_map, input_txt_dictionary) - parsed_literal = literal_map.literals["inputa"].value - ll = literals.Scalar(primitive=literals.Primitive(integer=15)) - assert parsed_literal == ll - - -def test_construct_literal_map_from_parameter_map(): - v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") - p = Parameter(var=v, required=True) - pm = ParameterMap(parameters={"inputa": p}) - - input_txt_dictionary = {"inputa": "15"} - - literal_map = helpers.construct_literal_map_from_parameter_map(pm, input_txt_dictionary) - parsed_literal = literal_map.literals["inputa"].value - ll = literals.Scalar(primitive=literals.Primitive(integer=15)) - assert parsed_literal == ll - - with pytest.raises(Exception): - helpers.construct_literal_map_from_parameter_map(pm, {}) - - def test_strtobool(): assert not helpers.str2bool("False") assert not helpers.str2bool("OFF") diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py b/tests/flytekit/unit/cli/test_flyte_cli.py index 3049590036..7c5830b728 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -4,54 +4,17 @@ from click.testing import CliRunner as _CliRunner from flytekit.clis.flyte_cli import main as _main -from flytekit.common.exceptions.user import FlyteAssertion -from flytekit.common.types import primitives -from flytekit.configuration import TemporaryConfiguration +from flytekit.exceptions.user import FlyteAssertion from flytekit.models import filters as _filters from flytekit.models.admin import common as _admin_common from flytekit.models.core import identifier as _core_identifier from flytekit.models.project import Project as _Project -from flytekit.sdk.tasks import inputs, outputs, python_task mm = _mock.MagicMock() mm.return_value = 100 -def get_sample_task(): - """ - :rtype: flytekit.common.tasks.task.SdkTask - """ - - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - return my_task - - -@_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") -def test__extract_files(load_mock): - t = get_sample_task() - with TemporaryConfiguration( - "", - internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, - ): - task_spec = t.serialize() - - load_mock.side_effect = [task_spec] - new_id, entity = 
_main._extract_pair("a", 1, "myproject", "development", "v", {}) - assert ( - new_id - == _core_identifier.Identifier( - _core_identifier.ResourceType.TASK, "myproject", "development", "test_flyte_cli.my_task", "v" - ).to_flyte_idl() - ) - assert task_spec == entity - - -@_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") +@_mock.patch("flytekit.clis.flyte_cli.main.utils") def test__extract_files_with_unspecified_resource_type(load_mock): id = _core_identifier.Identifier( _core_identifier.ResourceType.UNSPECIFIED, @@ -61,7 +24,7 @@ def test__extract_files_with_unspecified_resource_type(load_mock): "v", ) - load_mock.return_value = id.to_flyte_idl() + load_mock.load_proto_from_file.return_value = id.to_flyte_idl() with pytest.raises(FlyteAssertion): _main._extract_pair("a", "b", "myflyteproject", "development", "v", {}) diff --git a/tests/flytekit/unit/common_tests/__init__.py b/tests/flytekit/unit/common_tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/exceptions/__init__.py b/tests/flytekit/unit/common_tests/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/mixins/__init__.py b/tests/flytekit/unit/common_tests/mixins/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/mixins/sample_registerable.py b/tests/flytekit/unit/common_tests/mixins/sample_registerable.py deleted file mode 100644 index b0ecb08ca9..0000000000 --- a/tests/flytekit/unit/common_tests/mixins/sample_registerable.py +++ /dev/null @@ -1,15 +0,0 @@ -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.mixins import registerable as _registerable - - -class ExampleRegisterable( - _registerable.RegisterableEntity, _registerable.TrackableEntity, metaclass=_sdk_bases.ExtendedSdkType -): - def __init__(self, *args, **kwargs): - super(ExampleRegisterable, self).__init__(*args, **kwargs) - - def promote_from_model(cls, base_model): - pass - - -example = ExampleRegisterable() diff --git a/tests/flytekit/unit/common_tests/mixins/test_registerable.py b/tests/flytekit/unit/common_tests/mixins/test_registerable.py deleted file mode 100644 index 030d5bc1e9..0000000000 --- a/tests/flytekit/unit/common_tests/mixins/test_registerable.py +++ /dev/null @@ -1,13 +0,0 @@ -from tests.flytekit.unit.common_tests.mixins import sample_registerable as _sample_registerable - - -def test_instance_tracker(): - assert _sample_registerable.example.instantiated_in == "tests.flytekit.unit.common_tests.mixins.sample_registerable" - - -def test_auto_name_assignment(): - _sample_registerable.example.auto_assign_name() - assert ( - _sample_registerable.example.platform_valid_name - == "tests.flytekit.unit.common_tests.mixins.sample_registerable.example" - ) diff --git a/tests/flytekit/unit/common_tests/tasks/__init__.py b/tests/flytekit/unit/common_tests/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/tasks/spark/__init__.py b/tests/flytekit/unit/common_tests/tasks/spark/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py b/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py deleted file mode 100644 index 9389fdc7ea..0000000000 --- a/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py +++ /dev/null @@ -1,37 +0,0 @@ -from six.moves import range - 
-from flytekit.sdk.tasks import outputs, spark_task -from flytekit.sdk.types import Types - -# This file is in a subdirectory to make it easier to exclude when not running in a container -# and pyspark is not available - - -@outputs(out=Types.Integer) -@spark_task(retries=1) -def my_spark_task(wf, sc, out): - def _inside(p): - return p < 1000 - - count = sc.parallelize(range(0, 10000)).filter(_inside).count() - out.set(count) - - -@outputs(out=Types.Integer) -@spark_task(retries=3) -def my_spark_task2(wf, sc, out): - # This test makes sure spark_task doesn't choke on a non-package module and modules which overlap with auto-included - # modules. - def _inside(p): - return p < 500 - - count = sc.parallelize(range(0, 10000)).filter(_inside).count() - out.set(count) - - -def test_basic_spark_execution(): - outputs = my_spark_task.unit_test() - assert outputs["out"] == 1000 - - outputs = my_spark_task2.unit_test() - assert outputs["out"] == 500 diff --git a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py b/tests/flytekit/unit/common_tests/tasks/test_execution_params.py deleted file mode 100644 index cf634bfa96..0000000000 --- a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py +++ /dev/null @@ -1,76 +0,0 @@ -import os - -import py -import pytest - -from flytekit.common.tasks.sdk_runnable import SecretsManager -from flytekit.configuration import secrets - - -def test_secrets_manager_default(): - with pytest.raises(ValueError): - sec = SecretsManager() - sec.get("group", "key") - - -def test_secrets_manager_get_envvar(): - sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_env_var("test", "") - with pytest.raises(ValueError): - sec.get_secrets_env_var("", "x") - assert sec.get_secrets_env_var("group", "test") == f"{secrets.SECRETS_ENV_PREFIX.get()}GROUP_TEST" - - -def test_secrets_manager_get_file(): - sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_file("test", "") - with pytest.raises(ValueError): - sec.get_secrets_file("", "x") - assert sec.get_secrets_file("group", "test") == os.path.join( - secrets.SECRETS_DEFAULT_DIR.get(), - "group", - f"{secrets.SECRETS_FILE_PREFIX.get()}test", - ) - - -def test_secrets_manager_file(tmpdir: py.path.local): - tmp = tmpdir.mkdir("file_test").dirname - os.environ["FLYTE_SECRETS_DEFAULT_DIR"] = tmp - sec = SecretsManager() - f = os.path.join(tmp, "test") - with open(f, "w+") as w: - w.write("my-password") - - with pytest.raises(ValueError): - sec.get("test", "") - with pytest.raises(ValueError): - sec.get("", "x") - # Group dir not exists - with pytest.raises(ValueError): - sec.get("group", "test") - - g = os.path.join(tmp, "group") - os.makedirs(g) - f = os.path.join(g, "test") - with open(f, "w+") as w: - w.write("my-password") - assert sec.get("group", "test") == "my-password" - del os.environ["FLYTE_SECRETS_DEFAULT_DIR"] - - -def test_secrets_manager_bad_env(): - with pytest.raises(ValueError): - os.environ["TEST"] = "value" - sec = SecretsManager() - sec.get("group", "test") - - -def test_secrets_manager_env(): - sec = SecretsManager() - os.environ[sec.get_secrets_env_var("group", "test")] = "value" - assert sec.get("group", "test") == "value" - - os.environ[sec.get_secrets_env_var(group="group", key="key")] = "value" - assert sec.get(group="group", key="key") == "value" diff --git a/tests/flytekit/unit/common_tests/tasks/test_raw_container_task.py b/tests/flytekit/unit/common_tests/tasks/test_raw_container_task.py deleted file mode 100644 index 1267ce1c8b..0000000000 --- 
a/tests/flytekit/unit/common_tests/tasks/test_raw_container_task.py +++ /dev/null @@ -1,27 +0,0 @@ -from flytekit.common.tasks.raw_container import SdkRawContainerTask -from flytekit.sdk.types import Types - - -def test_raw_container_task_definition(): - tk = SdkRawContainerTask( - inputs={"x": Types.Integer}, - outputs={"y": Types.Integer}, - image="my-image", - command=["echo", "hello, world!"], - gpu_limit="1", - gpu_request="1", - ) - assert not tk.serialize() is None - - -def test_raw_container_task_definition_no_outputs(): - tk = SdkRawContainerTask( - inputs={"x": Types.Integer}, - image="my-image", - command=["echo", "hello, world!"], - gpu_limit="1", - gpu_request="1", - ) - assert not tk.serialize() is None - task_instance = tk(x=3) - assert task_instance.inputs[0].binding.scalar.primitive.integer == 3 diff --git a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py b/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py deleted file mode 100644 index 43e677551c..0000000000 --- a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest as _pytest - -from flytekit.common import constants as _common_constants -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import sdk_runnable -from flytekit.common.types import primitives -from flytekit.models import interface - - -def test_basic_unit_test(): - def add_one(wf_params, value_in, value_out): - value_out.set(value_in + 1) - - t = sdk_runnable.SdkRunnableTask( - add_one, - _common_constants.SdkTaskType.PYTHON_TASK, - "1", - 1, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - False, - None, - {}, - False, - None, - ) - t.add_inputs({"value_in": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) - t.add_outputs({"value_out": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) - out = t.unit_test(value_in=1) - assert out["value_out"] == 2 - - with _pytest.raises(_user_exceptions.FlyteAssertion) as e: - t() - - assert "value_in" in str(e.value) - assert "INTEGER" in str(e.value) diff --git a/tests/flytekit/unit/common_tests/tasks/test_task.py b/tests/flytekit/unit/common_tests/tasks/test_task.py deleted file mode 100644 index e33a757412..0000000000 --- a/tests/flytekit/unit/common_tests/tasks/test_task.py +++ /dev/null @@ -1,108 +0,0 @@ -import os as _os - -import pytest as _pytest -from flyteidl.admin import task_pb2 as _admin_task_pb2 -from mock import MagicMock as _MagicMock -from mock import patch as _patch - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import task as _task -from flytekit.common.tasks.presto_task import SdkPrestoTask -from flytekit.common.types import primitives -from flytekit.configuration import TemporaryConfiguration -from flytekit.models import task as _task_models -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types - - -@_patch("flytekit.engines.flyte.engine._FlyteClientManager") -@_patch("flytekit.configuration.platform.URL") -def test_fetch_latest(mock_url, mock_client_manager): - mock_url.get.return_value = "localhost" - admin_task = _task_models.Task( - _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), - _MagicMock(), - ) - mock_client = _MagicMock() - mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task], "")) - 
mock_client_manager.return_value.client = mock_client - task = _task.SdkTask.fetch_latest("p1", "d1", "n1") - assert task.id == admin_task.id - - -@_patch("flytekit.engines.flyte.engine._FlyteClientManager") -@_patch("flytekit.configuration.platform.URL") -def test_fetch_latest_not_exist(mock_url, mock_client_manager): - mock_client = _MagicMock() - mock_client.list_tasks_paginated = _MagicMock(return_value=(None, "")) - mock_client_manager.return_value.client = mock_client - mock_url.get.return_value = "localhost" - with _pytest.raises(_user_exceptions.FlyteEntityNotExistException): - _task.SdkTask.fetch_latest("p1", "d1", "n1") - - -def get_sample_task(): - """ - :rtype: flytekit.common.tasks.task.SdkTask - """ - - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - return my_task - - -def test_task_serialization(): - t = get_sample_task() - with TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../../common/configs/local.config", - ), - internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, - ): - s = t.serialize() - - assert isinstance(s, _admin_task_pb2.TaskSpec) - assert s.template.id.name == "tests.flytekit.unit.common_tests.tasks.test_task.my_task" - assert s.template.container.image == "myflyteimage:v123" - - -schema = Types.Schema([("a", Types.String), ("b", Types.Integer)]) - - -def test_task_produce_deterministic_version(): - containerless_task = SdkPrestoTask( - task_inputs=inputs(ds=Types.String, rg=Types.String), - statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 10", - output_schema=schema, - routing_group="{{ .Inputs.rg }}", - ) - identical_containerless_task = SdkPrestoTask( - task_inputs=inputs(ds=Types.String, rg=Types.String), - statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 10", - output_schema=schema, - routing_group="{{ .Inputs.rg }}", - ) - different_containerless_task = SdkPrestoTask( - task_inputs=inputs(ds=Types.String, rg=Types.String), - statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 100000", - output_schema=schema, - routing_group="{{ .Inputs.rg }}", - ) - assert ( - containerless_task._produce_deterministic_version() - == identical_containerless_task._produce_deterministic_version() - ) - - assert ( - containerless_task._produce_deterministic_version() - != different_containerless_task._produce_deterministic_version() - ) - - with _pytest.raises(Exception): - get_sample_task()._produce_deterministic_version() diff --git a/tests/flytekit/unit/common_tests/test_interface.py b/tests/flytekit/unit/common_tests/test_interface.py deleted file mode 100644 index b6627c1a6b..0000000000 --- a/tests/flytekit/unit/common_tests/test_interface.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest - -from flytekit.common import interface -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import containers, primitives - - -def test_binding_data_primitive_static(): - upstream_nodes = set() - bd = interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), 3.0, upstream_nodes=upstream_nodes - ) - - assert len(upstream_nodes) == 0 - assert bd.promise is None - assert bd.collection is None - assert bd.map is None - assert bd.scalar.primitive.float_value == 3.0 - - assert interface.BindingData.from_flyte_idl(bd.to_flyte_idl()) == bd - - with 
pytest.raises(_user_exceptions.FlyteTypeException): - interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), - "abc", - ) - - with pytest.raises(_user_exceptions.FlyteTypeException): - interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), - [1.0, 2.0, 3.0], - ) - - -def test_binding_data_list_static(): - upstream_nodes = set() - bd = interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), - ["abc", "cde"], - upstream_nodes=upstream_nodes, - ) - - assert len(upstream_nodes) == 0 - assert bd.promise is None - assert bd.collection.bindings[0].scalar.primitive.string_value == "abc" - assert bd.collection.bindings[1].scalar.primitive.string_value == "cde" - assert bd.map is None - assert bd.scalar is None - - assert interface.BindingData.from_flyte_idl(bd.to_flyte_idl()) == bd - - with pytest.raises(_user_exceptions.FlyteTypeException): - interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), - "abc", - ) - - with pytest.raises(_user_exceptions.FlyteTypeException): - interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), [1.0, 2.0, 3.0] - ) - - -def test_binding_generic_map_static(): - upstream_nodes = set() - bd = interface.BindingData.from_python_std( - primitives.Generic.to_flyte_literal_type(), - {"a": "hi", "b": [1, 2, 3], "c": {"d": "e"}}, - upstream_nodes=upstream_nodes, - ) - - assert len(upstream_nodes) == 0 - assert bd.promise is None - assert bd.map is None - assert bd.scalar.generic["a"] == "hi" - assert bd.scalar.generic["b"].values[0].number_value == 1.0 - assert bd.scalar.generic["b"].values[1].number_value == 2.0 - assert bd.scalar.generic["b"].values[2].number_value == 3.0 - assert bd.scalar.generic["c"]["d"] == "e" - assert interface.BindingData.from_flyte_idl(bd.to_flyte_idl()) == bd diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py deleted file mode 100644 index 92943d37f1..0000000000 --- a/tests/flytekit/unit/common_tests/test_launch_plan.py +++ /dev/null @@ -1,388 +0,0 @@ -import os as _os - -import pytest as _pytest - -from flytekit import configuration as _configuration -from flytekit.common import launch_plan as _launch_plan -from flytekit.common import notifications as _notifications -from flytekit.common import schedules as _schedules -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_models -from flytekit.models import schedule as _schedule -from flytekit.models import types as _type_models -from flytekit.models.core import execution as _execution -from flytekit.models.core import identifier as _identifier -from flytekit.sdk import types as _types -from flytekit.sdk import workflow as _workflow - - -def test_default_assumable_iam_role(): - with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../common/configs/local.config", - ) - ): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan() - assert lp.auth_role.assumable_iam_role == "arn:aws:iam::ABC123:role/my-flyte-role" - - -def test_hard_coded_assumable_iam_role(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": 
_workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan(assumable_iam_role="override") - assert lp.auth_role.assumable_iam_role == "override" - - -def test_default_deprecated_role(): - with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../common/configs/deprecated_local.config", - ) - ): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan() - assert lp.auth_role.assumable_iam_role == "arn:aws:iam::ABC123:role/my-flyte-role" - - -def test_hard_coded_deprecated_role(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan(role="override") - assert lp.auth_role.assumable_iam_role == "override" - - -def test_kubernetes_service_account(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan(kubernetes_service_account="kube-service-acct") - assert lp.auth_role.kubernetes_service_account == "kube-service-acct" - - -def test_fixed_inputs(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 4}) - assert len(lp.fixed_inputs.literals) == 1 - assert lp.fixed_inputs.literals["required_input"].scalar.primitive.integer == 4 - assert len(lp.default_inputs.parameters) == 1 - assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5 - - -def test_redefining_inputs_good(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - default_inputs={"required_input": _workflow.Input(_types.Types.Integer, default=900)} - ) - assert len(lp.fixed_inputs.literals) == 0 - assert len(lp.default_inputs.parameters) == 2 - assert lp.default_inputs.parameters["required_input"].default.scalar.primitive.integer == 900 - assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5 - - -def test_no_additional_inputs(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan() - assert len(lp.fixed_inputs.literals) == 0 - assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5 - assert lp.default_inputs.parameters["required_input"].required is True - - -@_pytest.mark.parametrize( - "schedule,cron_expression,cron_schedule", - [ - (_schedules.CronSchedule("* * ? * * *"), "* * ? * * *", None), - (_schedules.CronSchedule(cron_expression="* * ? * * *"), "* * ? * * *", None), - (_schedules.CronSchedule(cron_expression="0/15 * * * ? *"), "0/15 * * * ? 
*", None), - (_schedules.CronSchedule(schedule="* * * * *"), None, _schedule.Schedule.CronSchedule("* * * * *", None)), - ( - _schedules.CronSchedule(schedule="* * * * *", offset="P1D"), - None, - _schedule.Schedule.CronSchedule("* * * * *", "P1D"), - ), - ], -) -def test_schedule(schedule, cron_expression, cron_schedule): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - schedule=schedule, - role="what", - ) - assert lp.entity_metadata.schedule.kickoff_time_input_arg is None - assert lp.entity_metadata.schedule.cron_expression == cron_expression - assert lp.entity_metadata.schedule.cron_schedule == cron_schedule - assert lp.is_scheduled - - -def test_no_schedule(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan() - assert lp.entity_metadata.schedule.kickoff_time_input_arg == "" - assert lp.entity_metadata.schedule.schedule_expression is None - assert not lp.is_scheduled - - -def test_schedule_pointing_to_datetime(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Datetime), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), - role="what", - ) - assert lp.entity_metadata.schedule.kickoff_time_input_arg == "required_input" - assert lp.entity_metadata.schedule.cron_expression == "* * ? 
* * *" - - -def test_notifications(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - notifications=[_notifications.PagerDuty([_execution.WorkflowExecutionPhase.FAILED], ["me@myplace.com"])] - ) - assert len(lp.entity_metadata.notifications) == 1 - assert lp.entity_metadata.notifications[0].pager_duty.recipients_email == ["me@myplace.com"] - assert lp.entity_metadata.notifications[0].phases == [_execution.WorkflowExecutionPhase.FAILED] - - -def test_no_notifications(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan() - assert len(lp.entity_metadata.notifications) == 0 - - -def test_launch_plan_node(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - outputs={"out": _workflow.Output([1, 2, 3], sdk_type=[_types.Types.Integer])}, - ) - lp = workflow_to_test.create_launch_plan() - - # Test that required input isn't set - with _pytest.raises(_user_exceptions.FlyteAssertion): - lp() - - # Test that positional args are rejected - with _pytest.raises(_user_exceptions.FlyteAssertion): - lp(1, 2) - - # Test that type checking works - with _pytest.raises(_user_exceptions.FlyteTypeException): - lp(required_input="abc", default_input=1) - - # Test that bad arg name is detected - with _pytest.raises(_user_exceptions.FlyteAssertion): - lp(required_input=1, bad_arg=1) - - # Test default input is accounted for - n = lp(required_input=10) - assert n.inputs[0].var == "default_input" - assert n.inputs[0].binding.scalar.primitive.integer == 5 - assert n.inputs[1].var == "required_input" - assert n.inputs[1].binding.scalar.primitive.integer == 10 - - # Test default input is overridden - n = lp(required_input=10, default_input=50) - assert n.inputs[0].var == "default_input" - assert n.inputs[0].binding.scalar.primitive.integer == 50 - assert n.inputs[1].var == "required_input" - assert n.inputs[1].binding.scalar.primitive.integer == 10 - - # Test that launch plan ID ref is flexible - lp._id = "fake" - assert n.workflow_node.launchplan_ref == "fake" - lp._id = None - - # Test that outputs are promised - n.assign_id_and_return("node-id") - assert n.outputs["out"].sdk_type.to_flyte_literal_type().collection_type.simple == _type_models.SimpleType.INTEGER - assert n.outputs["out"].var == "out" - assert n.outputs["out"].node_id == "node-id" - - -def test_labels(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - schedule=_schedules.CronSchedule("* * ? 
* * *"), - role="what", - labels=_common_models.Labels({"my": "label"}), - ) - assert lp.labels.values == {"my": "label"} - - -def test_annotations(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - schedule=_schedules.CronSchedule("* * ? * * *"), - role="what", - annotations=_common_models.Annotations({"my": "annotation"}), - ) - assert lp.annotations.values == {"my": "annotation"} - - -def test_serialize(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - workflow_to_test.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - role="iam_role", - ) - - with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../common/configs/local.config", - ), - internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, - ): - s = lp.serialize() - - assert ( - s.spec.workflow_id - == _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v").to_flyte_idl() - ) - assert s.spec.auth_role.assumable_iam_role == "iam_role" - assert s.spec.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5 - - -def test_promote_from_model(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - workflow_to_test.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - schedule=_schedules.CronSchedule("* * ? 
* * *"), - role="what", - labels=_common_models.Labels({"my": "label"}), - ) - - with _pytest.raises(_user_exceptions.FlyteAssertion): - _launch_plan.SdkRunnableLaunchPlan.from_flyte_idl(lp.to_flyte_idl()) - - lp_from_spec = _launch_plan.SdkLaunchPlan.from_flyte_idl(lp.to_flyte_idl()) - assert not isinstance(lp_from_spec, _launch_plan.SdkRunnableLaunchPlan) - assert isinstance(lp_from_spec, _launch_plan.SdkLaunchPlan) - assert lp_from_spec == lp - - -def test_raw_data_output_prefix(): - workflow_to_test = _workflow.workflow( - {}, - inputs={ - "required_input": _workflow.Input(_types.Types.Integer), - "default_input": _workflow.Input(_types.Types.Integer, default=5), - }, - ) - lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - raw_output_data_prefix="s3://bucket-name", - ) - assert lp.raw_output_data_config.output_location_prefix == "s3://bucket-name" - - lp2 = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, - ) - assert lp2.raw_output_data_config.output_location_prefix == "" diff --git a/tests/flytekit/unit/common_tests/test_nodes.py b/tests/flytekit/unit/common_tests/test_nodes.py deleted file mode 100644 index 0c462c552c..0000000000 --- a/tests/flytekit/unit/common_tests/test_nodes.py +++ /dev/null @@ -1,292 +0,0 @@ -import datetime as _datetime - -import pytest as _pytest - -from flytekit.common import component_nodes as _component_nodes -from flytekit.common import interface as _interface -from flytekit.common import nodes as _nodes -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.models import literals as _literals -from flytekit.models.core import identifier as _identifier -from flytekit.models.core import workflow as _core_workflow_models -from flytekit.sdk import tasks as _tasks -from flytekit.sdk import types as _types -from flytekit.sdk import workflow as _workflow - - -def test_sdk_node_from_task(): - @_tasks.inputs(a=_types.Types.Integer) - @_tasks.outputs(b=_types.Types.Integer) - @_tasks.python_task() - def testy_test(wf_params, a, b): - pass - - n = _nodes.SdkNode( - "n", - [], - [ - _literals.Binding( - "a", - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), - ) - ], - _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), - sdk_task=testy_test, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - ) - - assert n.id == "n" - assert len(n.inputs) == 1 - assert n.inputs[0].var == "a" - assert n.inputs[0].binding.scalar.primitive.integer == 3 - assert len(n.outputs) == 1 - assert "b" in n.outputs - assert n.outputs["b"].node_id == "n" - assert n.outputs["b"].var == "b" - assert n.outputs["b"].sdk_node == n - assert n.outputs["b"].sdk_type == _types.Types.Integer - assert n.metadata.name == "abc" - assert n.metadata.retries.retries == 3 - assert n.metadata.interruptible is None - assert len(n.upstream_nodes) == 0 - assert len(n.upstream_node_ids) == 0 - assert len(n.output_aliases) == 0 - - n2 = _nodes.SdkNode( - "n2", - [n], - [ - _literals.Binding( - "a", - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), - ) - ], - _core_workflow_models.NodeMetadata("abc2", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), - sdk_task=testy_test, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - ) - - assert n2.id == "n2" - assert len(n2.inputs) == 1 - assert n2.inputs[0].var == "a" - assert n2.inputs[0].binding.promise.var 
== "b" - assert n2.inputs[0].binding.promise.node_id == "n" - assert len(n2.outputs) == 1 - assert "b" in n2.outputs - assert n2.outputs["b"].node_id == "n2" - assert n2.outputs["b"].var == "b" - assert n2.outputs["b"].sdk_node == n2 - assert n2.outputs["b"].sdk_type == _types.Types.Integer - assert n2.metadata.name == "abc2" - assert n2.metadata.retries.retries == 3 - assert "n" in n2.upstream_node_ids - assert n in n2.upstream_nodes - assert len(n2.upstream_nodes) == 1 - assert len(n2.upstream_node_ids) == 1 - assert len(n2.output_aliases) == 0 - - # Test right shift operator and late binding - n3 = _nodes.SdkNode( - "n3", - [], - [ - _literals.Binding( - "a", - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), - ) - ], - _core_workflow_models.NodeMetadata("abc3", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), - sdk_task=testy_test, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - ) - n2 >> n3 - n >> n2 >> n3 - n3 << n2 - n3 << n2 << n - - assert n3.id == "n3" - assert len(n3.inputs) == 1 - assert n3.inputs[0].var == "a" - assert n3.inputs[0].binding.scalar.primitive.integer == 3 - assert len(n3.outputs) == 1 - assert "b" in n3.outputs - assert n3.outputs["b"].node_id == "n3" - assert n3.outputs["b"].var == "b" - assert n3.outputs["b"].sdk_node == n3 - assert n3.outputs["b"].sdk_type == _types.Types.Integer - assert n3.metadata.name == "abc3" - assert n3.metadata.retries.retries == 3 - assert "n2" in n3.upstream_node_ids - assert n2 in n3.upstream_nodes - assert len(n3.upstream_nodes) == 1 - assert len(n3.upstream_node_ids) == 1 - assert len(n3.output_aliases) == 0 - - # Test left shift operator and late binding - n4 = _nodes.SdkNode( - "n4", - [], - [ - _literals.Binding( - "a", - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), - ) - ], - _core_workflow_models.NodeMetadata("abc4", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), - sdk_task=testy_test, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - ) - - n4 << n3 - - # Test that implicit dependencies don't cause direct dependencies - n4 << n3 << n2 << n - n >> n2 >> n3 >> n4 - - assert n4.id == "n4" - assert len(n4.inputs) == 1 - assert n4.inputs[0].var == "a" - assert n4.inputs[0].binding.scalar.primitive.integer == 3 - assert len(n4.outputs) == 1 - assert "b" in n4.outputs - assert n4.outputs["b"].node_id == "n4" - assert n4.outputs["b"].var == "b" - assert n4.outputs["b"].sdk_node == n4 - assert n4.outputs["b"].sdk_type == _types.Types.Integer - assert n4.metadata.name == "abc4" - assert n4.metadata.retries.retries == 3 - assert "n3" in n4.upstream_node_ids - assert n3 in n4.upstream_nodes - assert len(n4.upstream_nodes) == 1 - assert len(n4.upstream_node_ids) == 1 - assert len(n4.output_aliases) == 0 - - # Add another dependency - n4 << n2 - assert "n3" in n4.upstream_node_ids - assert n3 in n4.upstream_nodes - assert "n2" in n4.upstream_node_ids - assert n2 in n4.upstream_nodes - assert len(n4.upstream_nodes) == 2 - assert len(n4.upstream_node_ids) == 2 - - -def test_sdk_task_node(): - @_tasks.inputs(a=_types.Types.Integer) - @_tasks.outputs(b=_types.Types.Integer) - @_tasks.python_task() - def testy_test(wf_params, a, b): - pass - - testy_test._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - n = _component_nodes.SdkTaskNode(testy_test) - assert n.reference_id.project == "project" - assert n.reference_id.domain == "domain" - assert 
n.reference_id.name == "name" - assert n.reference_id.version == "version" - - # Test floating ID - testy_test._id = _identifier.Identifier( - _identifier.ResourceType.TASK, - "new_project", - "new_domain", - "new_name", - "new_version", - ) - assert n.reference_id.project == "new_project" - assert n.reference_id.domain == "new_domain" - assert n.reference_id.name == "new_name" - assert n.reference_id.version == "new_version" - - -def test_sdk_node_from_lp(): - @_tasks.inputs(a=_types.Types.Integer) - @_tasks.outputs(b=_types.Types.Integer) - @_tasks.python_task() - def testy_test(wf_params, a, b): - pass - - @_workflow.workflow_class - class test_workflow(object): - a = _workflow.Input(_types.Types.Integer) - test = testy_test(a=a) - b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer) - - lp = test_workflow.create_launch_plan() - - n1 = _nodes.SdkNode( - "n1", - [], - [ - _literals.Binding( - "a", - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), - ) - ], - _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), - sdk_launch_plan=lp, - ) - - assert n1.id == "n1" - assert len(n1.inputs) == 1 - assert n1.inputs[0].var == "a" - assert n1.inputs[0].binding.scalar.primitive.integer == 3 - assert len(n1.outputs) == 1 - assert "b" in n1.outputs - assert n1.outputs["b"].node_id == "n1" - assert n1.outputs["b"].var == "b" - assert n1.outputs["b"].sdk_node == n1 - assert n1.outputs["b"].sdk_type == _types.Types.Integer - assert n1.metadata.name == "abc" - assert n1.metadata.retries.retries == 3 - assert len(n1.upstream_nodes) == 0 - assert len(n1.upstream_node_ids) == 0 - assert len(n1.output_aliases) == 0 - - -def test_sdk_launch_plan_node(): - @_tasks.inputs(a=_types.Types.Integer) - @_tasks.outputs(b=_types.Types.Integer) - @_tasks.python_task() - def testy_test(wf_params, a, b): - pass - - @_workflow.workflow_class - class test_workflow(object): - a = _workflow.Input(_types.Types.Integer) - test = testy_test(a=1) - b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer) - - lp = test_workflow.create_launch_plan() - - lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp) - assert n.launchplan_ref.project == "project" - assert n.launchplan_ref.domain == "domain" - assert n.launchplan_ref.name == "name" - assert n.launchplan_ref.version == "version" - - # Test floating ID - lp._id = _identifier.Identifier( - _identifier.ResourceType.TASK, - "new_project", - "new_domain", - "new_name", - "new_version", - ) - assert n.launchplan_ref.project == "new_project" - assert n.launchplan_ref.domain == "new_domain" - assert n.launchplan_ref.name == "new_name" - assert n.launchplan_ref.version == "new_version" - - # If you specify both, you should get an exception - with _pytest.raises(_system_exceptions.FlyteSystemException): - _component_nodes.SdkWorkflowNode(sdk_workflow=test_workflow, sdk_launch_plan=lp) diff --git a/tests/flytekit/unit/common_tests/test_notifications.py b/tests/flytekit/unit/common_tests/test_notifications.py deleted file mode 100644 index 8a2278097a..0000000000 --- a/tests/flytekit/unit/common_tests/test_notifications.py +++ /dev/null @@ -1,35 +0,0 @@ -from flytekit.common import notifications as _notifications -from flytekit.models.core import execution as _execution_model - - -def test_pager_duty(): - obj = 
_notifications.PagerDuty([_execution_model.WorkflowExecutionPhase.FAILED], ["me@myplace.com"])
-    assert obj.email is None
-    assert obj.slack is None
-    assert obj.phases == [_execution_model.WorkflowExecutionPhase.FAILED]
-    assert obj.pager_duty.recipients_email == ["me@myplace.com"]
-
-    obj2 = _notifications.PagerDuty.from_flyte_idl(obj.to_flyte_idl())
-    assert obj == obj2
-
-
-def test_slack():
-    obj = _notifications.Slack([_execution_model.WorkflowExecutionPhase.FAILED], ["me@myplace.com"])
-    assert obj.email is None
-    assert obj.pager_duty is None
-    assert obj.phases == [_execution_model.WorkflowExecutionPhase.FAILED]
-    assert obj.slack.recipients_email == ["me@myplace.com"]
-
-    obj2 = _notifications.Slack.from_flyte_idl(obj.to_flyte_idl())
-    assert obj == obj2
-
-
-def test_email():
-    obj = _notifications.Email([_execution_model.WorkflowExecutionPhase.FAILED], ["me@myplace.com"])
-    assert obj.pager_duty is None
-    assert obj.slack is None
-    assert obj.phases == [_execution_model.WorkflowExecutionPhase.FAILED]
-    assert obj.email.recipients_email == ["me@myplace.com"]
-
-    obj2 = _notifications.Email.from_flyte_idl(obj.to_flyte_idl())
-    assert obj == obj2
diff --git a/tests/flytekit/unit/common_tests/test_promise.py b/tests/flytekit/unit/common_tests/test_promise.py
deleted file mode 100644
index e06cca7dfa..0000000000
--- a/tests/flytekit/unit/common_tests/test_promise.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import pytest
-
-from flytekit import FlyteContextManager
-from flytekit.common import promise
-from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.types import base_sdk_types, primitives
-from flytekit.core.interface import Interface
-from flytekit.core.promise import Promise, create_native_named_tuple, extract_obj_name
-from flytekit.core.type_engine import TypeEngine
-from flytekit.models.types import LiteralType, SimpleType
-
-
-def test_input():
-    i = promise.Input("name", primitives.Integer, help="blah", default=None)
-    assert i.name == "name"
-    assert i.sdk_default is None
-    assert i.default == base_sdk_types.Void()
-    assert i.sdk_required is False
-    assert i.help == "blah"
-    assert i.var.description == "blah"
-    assert i.sdk_type == primitives.Integer
-
-    i = promise.Input("name2", primitives.Integer, default=1)
-    assert i.name == "name2"
-    assert i.sdk_default == 1
-    assert i.default == primitives.Integer(1)
-    assert i.required is None
-    assert i.sdk_required is False
-    assert i.help is None
-    assert i.var.description == ""
-    assert i.sdk_type == primitives.Integer
-
-    with pytest.raises(_user_exceptions.FlyteAssertion):
-        promise.Input("abc", primitives.Integer, required=True, default=1)
-
-
-def test_create_native_named_tuple():
-    ctx = FlyteContextManager.current_context()
-    t = create_native_named_tuple(ctx, promises=None, entity_interface=Interface())
-    assert t is None
-
-    p1 = Promise(var="x", val=TypeEngine.to_literal(ctx, 1, int, LiteralType(simple=SimpleType.INTEGER)))
-    p2 = Promise(var="y", val=TypeEngine.to_literal(ctx, 2, int, LiteralType(simple=SimpleType.INTEGER)))
-
-    t = create_native_named_tuple(ctx, promises=p1, entity_interface=Interface(outputs={"x": int}))
-    assert t
-    assert t == 1
-
-    t = create_native_named_tuple(ctx, promises=[], entity_interface=Interface())
-    assert t is None
-
-    t = create_native_named_tuple(ctx, promises=[p1, p2], entity_interface=Interface(outputs={"x": int, "y": int}))
-    assert t
-    assert t == (1, 2)
-
-    t = create_native_named_tuple(
-        ctx, promises=[p1, p2], entity_interface=Interface(outputs={"x": int, "y": int}, output_tuple_name="Tup")
-    )
-    assert t
-    assert t == (1, 2)
-    assert t.__class__.__name__ == "Tup"
-
-    with pytest.raises(KeyError):
-        create_native_named_tuple(
-            ctx, promises=[p1, p2], entity_interface=Interface(outputs={"x": int}, output_tuple_name="Tup")
-        )
-
-    with pytest.raises(AssertionError, match="Failed to convert value of output x"):
-        create_native_named_tuple(ctx, promises=[p1, p2], entity_interface=Interface(outputs={"x": Promise, "y": int}))
-
-    with pytest.raises(AssertionError, match="Failed to convert value of output x"):
-        create_native_named_tuple(ctx, promises=p1, entity_interface=Interface(outputs={"x": Promise}))
-
-
-@pytest.mark.parametrize(
-    "name, expected_name",
-    [
-        ("test", "test"),
-        ("test.abc", "abc"),
-        (".test", "test"),
-        ("test.", ""),
-        ("test.xyz.abc", "abc"),
-        ("", ""),
-        (None, ""),
-        ("test.xyz.abc.", ""),
-    ],
-)
-def test_extract_obj_name(name, expected_name):
-    assert extract_obj_name(name) == expected_name
diff --git a/tests/flytekit/unit/common_tests/test_schedules.py b/tests/flytekit/unit/common_tests/test_schedules.py
deleted file mode 100644
index 4e6f231307..0000000000
--- a/tests/flytekit/unit/common_tests/test_schedules.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import datetime as _datetime
-
-import pytest as _pytest
-
-from flytekit.common import schedules as _schedules
-from flytekit.common.exceptions import user as _user_exceptions
-
-
-def test_cron():
-    obj = _schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="abc")
-    assert obj.kickoff_time_input_arg == "abc"
-    assert obj.cron_expression == "* * ? * * *"
-    assert obj == _schedules.CronSchedule.from_flyte_idl(obj.to_flyte_idl())
-
-
-def test_cron_karg():
-    obj = _schedules.CronSchedule(cron_expression="* * ? * * *", kickoff_time_input_arg="abc")
-    assert obj.kickoff_time_input_arg == "abc"
-    assert obj.cron_expression == "* * ? * * *"
-    assert obj == _schedules.CronSchedule.from_flyte_idl(obj.to_flyte_idl())
-
-
-def test_cron_validation():
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule("* * * * * *", kickoff_time_input_arg="abc")
-
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule("* * ? * *", kickoff_time_input_arg="abc")
-
-
-def test_fixed_rate():
-    obj = _schedules.FixedRate(_datetime.timedelta(hours=10), kickoff_time_input_arg="abc")
-    assert obj.rate.unit == _schedules.FixedRate.FixedRateUnit.HOUR
-    assert obj.rate.value == 10
-    assert obj == _schedules.FixedRate.from_flyte_idl(obj.to_flyte_idl())
-
-    obj = _schedules.FixedRate(_datetime.timedelta(hours=24), kickoff_time_input_arg="abc")
-    assert obj.rate.unit == _schedules.FixedRate.FixedRateUnit.DAY
-    assert obj.rate.value == 1
-    assert obj == _schedules.FixedRate.from_flyte_idl(obj.to_flyte_idl())
-
-    obj = _schedules.FixedRate(_datetime.timedelta(minutes=30), kickoff_time_input_arg="abc")
-    assert obj.rate.unit == _schedules.FixedRate.FixedRateUnit.MINUTE
-    assert obj.rate.value == 30
-    assert obj == _schedules.FixedRate.from_flyte_idl(obj.to_flyte_idl())
-
-    obj = _schedules.FixedRate(_datetime.timedelta(minutes=120), kickoff_time_input_arg="abc")
-    assert obj.rate.unit == _schedules.FixedRate.FixedRateUnit.HOUR
-    assert obj.rate.value == 2
-    assert obj == _schedules.FixedRate.from_flyte_idl(obj.to_flyte_idl())
-
-
-def test_fixed_rate_bad_duration():
-    pass
-
-
-def test_fixed_rate_negative_duration():
-    pass
-
-
-@_pytest.mark.parametrize(
-    "schedule",
-    [
-        "hourly",
-        "hours",
-        "HOURS",
-        "@hourly",
-        "daily",
-        "days",
-        "DAYS",
-        "@daily",
-        "weekly",
-        "weeks",
-        "WEEKS",
-        "@weekly",
-        "monthly",
-        "months",
-        "MONTHS",
-        "@monthly",
-        "annually",
-        "@annually",
-        "yearly",
-        "years",
-        "YEARS",
-        "@yearly",
-        "* * * * *",
-    ],
-)
-def test_cron_schedule_schedule_validation(schedule):
-    obj = _schedules.CronSchedule(schedule=schedule, kickoff_time_input_arg="abc")
-    assert obj.cron_schedule.schedule == schedule
-
-
-@_pytest.mark.parametrize(
-    "schedule",
-    ["foo", "* *"],
-)
-def test_cron_schedule_schedule_validation_invalid(schedule):
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule(schedule=schedule, kickoff_time_input_arg="abc")
-
-
-def test_cron_schedule_offset_validation_invalid():
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule(schedule="days", offset="foo", kickoff_time_input_arg="abc")
-
-
-def test_cron_schedule():
-    obj = _schedules.CronSchedule(schedule="days", kickoff_time_input_arg="abc")
-    assert obj.cron_schedule.schedule == "days"
-    assert obj.cron_schedule.offset is None
-    assert obj == _schedules.CronSchedule.from_flyte_idl(obj.to_flyte_idl())
-
-
-def test_cron_schedule_offset():
-    obj = _schedules.CronSchedule(schedule="days", offset="P1D", kickoff_time_input_arg="abc")
-    assert obj.cron_schedule.schedule == "days"
-    assert obj.cron_schedule.offset == "P1D"
-    assert obj == _schedules.CronSchedule.from_flyte_idl(obj.to_flyte_idl())
-
-
-def test_both_cron_expression_and_cron_schedule_schedule():
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule(
-            cron_expression="* * ? * * *", schedule="days", offset="foo", kickoff_time_input_arg="abc"
-        )
-
-
-def test_cron_expression_and_cron_schedule_offset():
-    with _pytest.raises(_user_exceptions.FlyteAssertion):
-        _schedules.CronSchedule(cron_expression="* * ? 
* * *", offset="foo", kickoff_time_input_arg="abc") diff --git a/tests/flytekit/unit/common_tests/test_workflow.py b/tests/flytekit/unit/common_tests/test_workflow.py deleted file mode 100644 index 13a9e65d8e..0000000000 --- a/tests/flytekit/unit/common_tests/test_workflow.py +++ /dev/null @@ -1,383 +0,0 @@ -import pytest as _pytest -from flyteidl.admin import workflow_pb2 as _workflow_pb2 - -from flytekit.common import constants, interface -from flytekit.common import local_workflow as _local_workflow -from flytekit.common import nodes, promise, workflow -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.local_workflow import build_sdk_workflow_from_metaclass -from flytekit.common.types import containers, primitives -from flytekit.models import literals as _literals -from flytekit.models.core import identifier as _identifier -from flytekit.models.core import workflow as _workflow_models -from flytekit.sdk import types as _types -from flytekit.sdk.tasks import inputs, outputs, python_task - - -def test_output(): - o = _local_workflow.Output("name", 1, sdk_type=primitives.Integer, help="blah") - assert o.name == "name" - assert o.var.description == "blah" - assert o.var.type == primitives.Integer.to_flyte_literal_type() - assert o.binding_data.scalar.primitive.integer == 1 - - -def test_workflow(): - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version") - - @inputs(a=[primitives.Integer]) - @outputs(b=[primitives.Integer]) - @python_task - def my_list_task(wf_params, a, b): - b.set([v + 1 for v in a]) - - my_list_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version" - ) - - input_list = [ - promise.Input("input_1", primitives.Integer), - promise.Input("input_2", primitives.Integer, default=5, help="Not required."), - ] - - n1 = my_task(a=input_list[0]).assign_id_and_return("n1") - n2 = my_task(a=input_list[1]).assign_id_and_return("n2") - n3 = my_task(a=100).assign_id_and_return("n3") - n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4") - n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5") - n6 = my_list_task(a=n5.outputs.b) - n1 >> n6 - - nodes = [n1, n2, n3, n4, n5, n6] - - w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition( - inputs=input_list, - outputs=[_local_workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer)], - nodes=nodes, - ) - - assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type() - assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type() - assert w.nodes[0].inputs[0].var == "a" - assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[0].inputs[0].binding.promise.var == "input_1" - assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[1].inputs[0].binding.promise.var == "input_2" - assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100 - assert w.nodes[3].inputs[0].var == "a" - assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id - - # Test conversion to flyte_idl and back - w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake", "faker", "fakest", "fakerest") - w = 
_workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl()) - assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type() - assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type() - assert w.nodes[0].inputs[0].var == "a" - assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[0].inputs[0].binding.promise.var == "input_1" - assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[1].inputs[0].binding.promise.var == "input_2" - assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100 - assert w.nodes[3].inputs[0].var == "a" - assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id - assert w.nodes[4].inputs[0].var == "a" - assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == "input_1" - assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == "input_2" - assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == n3.id - assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == "b" - assert w.nodes[4].inputs[0].binding.collection.bindings[3].scalar.primitive.integer == 100 - assert w.nodes[5].inputs[0].var == "a" - assert w.nodes[5].inputs[0].binding.promise.node_id == n5.id - assert w.nodes[5].inputs[0].binding.promise.var == "b" - - assert len(w.outputs) == 1 - assert w.outputs[0].var == "a" - assert w.outputs[0].binding.promise.var == "b" - assert w.outputs[0].binding.promise.node_id == "n1" - # TODO: Test promotion of w -> SdkWorkflow - - -def test_workflow_decorator(): - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task - def my_task(wf_params, a, b): - b.set(a + 1) - - my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "propject", "domain", "my_task", "version") - - @inputs(a=[primitives.Integer]) - @outputs(b=[primitives.Integer]) - @python_task - def my_list_task(wf_params, a, b): - b.set([v + 1 for v in a]) - - my_list_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "propject", "domain", "my_list_task", "version" - ) - - class my_workflow(object): - input_1 = promise.Input("input_1", primitives.Integer) - input_2 = promise.Input("input_2", primitives.Integer, default=5, help="Not required.") - n1 = my_task(a=input_1) - n2 = my_task(a=input_2) - n3 = my_task(a=100) - n4 = my_task(a=n1.outputs.b) - n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100]) - n6 = my_list_task(a=n5.outputs.b) - n1 >> n6 - a = _local_workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer) - - w = _local_workflow.build_sdk_workflow_from_metaclass( - my_workflow, - on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, - ) - - assert w.should_create_default_launch_plan is True - - assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type() - assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type() - assert w.nodes[0].inputs[0].var == "a" - assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[0].inputs[0].binding.promise.var == "input_1" - assert w.nodes[1].inputs[0].binding.promise.node_id == 
constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[1].inputs[0].binding.promise.var == "input_2" - assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100 - assert w.nodes[3].inputs[0].var == "a" - assert w.nodes[3].inputs[0].binding.promise.node_id == "n1" - - # Test conversion to flyte_idl and back - w.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake", "faker", "fakest", "fakerest") - w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl()) - assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type() - assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type() - assert w.nodes[0].inputs[0].var == "a" - assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[0].inputs[0].binding.promise.var == "input_1" - assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[1].inputs[0].binding.promise.var == "input_2" - assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100 - assert w.nodes[3].inputs[0].var == "a" - assert w.nodes[3].inputs[0].binding.promise.node_id == "n1" - assert w.nodes[4].inputs[0].var == "a" - assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == "input_1" - assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == "input_2" - assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == "n3" - assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == "b" - assert w.nodes[4].inputs[0].binding.collection.bindings[3].scalar.primitive.integer == 100 - assert w.nodes[5].inputs[0].var == "a" - assert w.nodes[5].inputs[0].binding.promise.node_id == "n5" - assert w.nodes[5].inputs[0].binding.promise.var == "b" - - assert len(w.outputs) == 1 - assert w.outputs[0].var == "a" - assert w.outputs[0].binding.promise.var == "b" - assert w.outputs[0].binding.promise.node_id == "n1" - assert ( - w.metadata.on_failure == _workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE - ) - # TODO: Test promotion of w -> SdkWorkflow - - -def test_workflow_node(): - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version") - - @inputs(a=[primitives.Integer]) - @outputs(b=[primitives.Integer]) - @python_task - def my_list_task(wf_params, a, b): - b.set([v + 1 for v in a]) - - my_list_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version" - ) - - input_list = [ - promise.Input("required", primitives.Integer), - promise.Input("not_required", primitives.Integer, default=5, help="Not required."), - ] - - n1 = my_task(a=input_list[0]).assign_id_and_return("n1") - n2 = my_task(a=input_list[1]).assign_id_and_return("n2") - n3 = my_task(a=100).assign_id_and_return("n3") - n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4") - n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5") - n6 = my_list_task(a=n5.outputs.b) - - nodes = [n1, n2, n3, n4, n5, n6] - - wf_out = [ - _local_workflow.Output( - "nested_out", - 
[n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], - sdk_type=[[primitives.Integer]], - ), - _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), - ] - - w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition( - inputs=input_list, outputs=wf_out, nodes=nodes - ) - - # Test that required input isn't set - with _pytest.raises(_user_exceptions.FlyteAssertion): - w() - - # Test that positional args are rejected - with _pytest.raises(_user_exceptions.FlyteAssertion): - w(1, 2) - - # Test that type checking works - with _pytest.raises(_user_exceptions.FlyteTypeException): - w(required="abc", not_required=1) - - # Test that bad arg name is detected - with _pytest.raises(_user_exceptions.FlyteAssertion): - w(required=1, bad_arg=1) - - # Test default input is accounted for - n = w(required=10) - assert n.inputs[0].var == "not_required" - assert n.inputs[0].binding.scalar.primitive.integer == 5 - assert n.inputs[1].var == "required" - assert n.inputs[1].binding.scalar.primitive.integer == 10 - - # Test default input is overridden - n = w(required=10, not_required=50) - assert n.inputs[0].var == "not_required" - assert n.inputs[0].binding.scalar.primitive.integer == 50 - assert n.inputs[1].var == "required" - assert n.inputs[1].binding.scalar.primitive.integer == 10 - - # Test that workflow is saved in the node - w.id = "fake" - assert n.workflow_node.sub_workflow_ref == "fake" - w.id = None - - # Test that outputs are promised - n.assign_id_and_return("node-id*") # dns'ified - assert n.outputs["scalar_out"].sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type() - assert n.outputs["scalar_out"].var == "scalar_out" - assert n.outputs["scalar_out"].node_id == "node-id" - - assert ( - n.outputs["nested_out"].sdk_type.to_flyte_literal_type() - == containers.List(containers.List(primitives.Integer)).to_flyte_literal_type() - ) - assert n.outputs["nested_out"].var == "nested_out" - assert n.outputs["nested_out"].node_id == "node-id" - - -def test_non_system_nodes(): - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version") - - required_input = promise.Input("required", primitives.Integer) - - n1 = my_task(a=required_input).assign_id_and_return("n1") - - n_start = nodes.SdkNode( - "start-node", - [], - [ - _literals.Binding( - "a", - interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), - ) - ], - None, - sdk_task=my_task, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None, - ) - - non_system_nodes = workflow.SdkWorkflow.get_non_system_nodes([n1, n_start]) - assert len(non_system_nodes) == 1 - assert non_system_nodes[0].id == "n1" - - -def test_workflow_serialization(): - @inputs(a=primitives.Integer) - @outputs(b=primitives.Integer) - @python_task() - def my_task(wf_params, a, b): - b.set(a + 1) - - my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version") - - @inputs(a=[primitives.Integer]) - @outputs(b=[primitives.Integer]) - @python_task - def my_list_task(wf_params, a, b): - b.set([v + 1 for v in a]) - - my_list_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version" - ) - - input_list = [ - promise.Input("required", primitives.Integer), - promise.Input("not_required", primitives.Integer, 
default=5, help="Not required."), - ] - - n1 = my_task(a=input_list[0]).assign_id_and_return("n1") - n2 = my_task(a=input_list[1]).assign_id_and_return("n2") - n3 = my_task(a=100).assign_id_and_return("n3") - n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4") - n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5") - n6 = my_list_task(a=n5.outputs.b) - - nodes = [n1, n2, n3, n4, n5, n6] - - wf_out = [ - _local_workflow.Output( - "nested_out", - [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], - sdk_type=[[primitives.Integer]], - ), - _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), - ] - - w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition( - inputs=input_list, outputs=wf_out, nodes=nodes - ) - serialized = w.serialize() - assert isinstance(serialized, _workflow_pb2.WorkflowSpec) - assert len(serialized.template.nodes) == 6 - assert len(serialized.template.interface.inputs.variables.keys()) == 2 - assert len(serialized.template.interface.outputs.variables.keys()) == 2 - - -def test_workflow_disable_default_launch_plan(): - class MyWorkflow(object): - input_1 = promise.Input("input_1", primitives.Integer) - input_2 = promise.Input("input_2", primitives.Integer, default=5, help="Not required.") - - w = build_sdk_workflow_from_metaclass( - MyWorkflow, - disable_default_launch_plan=True, - ) - - assert w.should_create_default_launch_plan is False diff --git a/tests/flytekit/unit/common_tests/test_workflow_promote.py b/tests/flytekit/unit/common_tests/test_workflow_promote.py index f165f4c231..f10ff22f1a 100644 --- a/tests/flytekit/unit/common_tests/test_workflow_promote.py +++ b/tests/flytekit/unit/common_tests/test_workflow_promote.py @@ -3,23 +3,11 @@ from flyteidl.core import compiler_pb2 as _compiler_pb2 from flyteidl.core import workflow_pb2 as _workflow_pb2 -from mock import patch as _patch -from flytekit.common import workflow as _workflow_common -from flytekit.common.tasks import task as _task -from flytekit.models import interface as _interface from flytekit.models import literals as _literals from flytekit.models import task as _task_model -from flytekit.models import types as _types from flytekit.models.core import compiler as _compiler_model -from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _workflow_model -from flytekit.sdk import tasks as _sdk_tasks -from flytekit.sdk import workflow as _sdk_workflow -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.types import Types as _Types -from flytekit.sdk.workflow import Input, Output, workflow_class def get_sample_node_metadata(node_id): @@ -108,60 +96,6 @@ class OneTaskWFForPromote(object): return wt -@_patch("flytekit.common.tasks.task.SdkTask.fetch") -def test_basic_workflow_promote(mock_task_fetch): - # This section defines a sample workflow from a user - @_sdk_tasks.inputs(a=_Types.Integer) - @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer) - @_sdk_tasks.python_task() - def demo_task_for_promote(wf_params, a, b, c): - b.set(a + 1) - c.set(a + 2) - - @_sdk_workflow.workflow_class() - class TestPromoteExampleWf(object): - wf_input = _sdk_workflow.Input(_Types.Integer, required=True) - my_task_node = demo_task_for_promote(a=wf_input) - wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer) - wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, 
sdk_type=_Types.Integer) - - # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow - int_type = _types.LiteralType(_types.SimpleType.INTEGER) - task_interface = _interface.TypedInterface( - # inputs - {"a": _interface.Variable(int_type, "description1")}, - # outputs - {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, - ) - # Since the promotion of a workflow requires retrieving the task from Admin, we mock the SdkTask to return - task_template = _task_model.TaskTemplate( - _identifier.Identifier( - _identifier.ResourceType.TASK, - "project", - "domain", - "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote", - "version", - ), - "python_container", - get_sample_task_metadata(), - task_interface, - custom={}, - container=get_sample_container(), - ) - sdk_promoted_task = _task.SdkTask.promote_from_model(task_template) - mock_task_fetch.return_value = sdk_promoted_task - workflow_template = get_workflow_template() - promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(workflow_template) - - assert promoted_wf.interface.inputs["wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"] - assert promoted_wf.interface.outputs["wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"] - assert promoted_wf.interface.outputs["wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"] - - assert len(promoted_wf.nodes) == 1 - assert len(TestPromoteExampleWf.nodes) == 1 - assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[0].inputs[0] - - def get_compiled_workflow_closure(): """ :rtype: flytekit.models.core.compiler.CompiledWorkflowClosure @@ -174,38 +108,3 @@ def get_compiled_workflow_closure(): cwc_pb.ParseFromString(fh.read()) return _compiler_model.CompiledWorkflowClosure.from_flyte_idl(cwc_pb) - - -def test_subworkflow_promote(): - cwc = get_compiled_workflow_closure() - primary = cwc.primary - sub_workflow_map = {sw.template.id: sw.template for sw in cwc.sub_workflows} - task_map = {t.template.id: t.template for t in cwc.tasks} - promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(primary.template, sub_workflow_map, task_map) - - # This file that the promoted_wf reads contains the compiled workflow closure protobuf retrieved from Admin - # after registering a workflow that basically looks like the one below. - - @inputs(num=Types.Integer) - @outputs(out=Types.Integer) - @python_task - def inner_task(wf_params, num, out): - wf_params.logging.info("Running inner task... 
setting output to input") - out.set(num) - - @workflow_class() - class IdentityWorkflow(object): - a = Input(Types.Integer, default=5, help="Input for inner workflow") - odd_nums_task = inner_task(num=a) - task_output = Output(odd_nums_task.outputs.out, sdk_type=Types.Integer) - - @workflow_class() - class StaticSubWorkflowCaller(object): - outer_a = Input(Types.Integer, default=5, help="Input for inner workflow") - identity_wf_execution = IdentityWorkflow(a=outer_a) - wf_output = Output(identity_wf_execution.outputs.task_output, sdk_type=Types.Integer) - - assert StaticSubWorkflowCaller.interface == promoted_wf.interface - assert StaticSubWorkflowCaller.nodes[0].id == promoted_wf.nodes[0].id - assert StaticSubWorkflowCaller.nodes[0].inputs == promoted_wf.nodes[0].inputs - assert StaticSubWorkflowCaller.outputs == promoted_wf.outputs diff --git a/tests/flytekit/unit/common_tests/types/impl/__init__.py b/tests/flytekit/unit/common_tests/types/impl/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/common_tests/types/impl/test_blobs.py b/tests/flytekit/unit/common_tests/types/impl/test_blobs.py deleted file mode 100644 index 3b61837e44..0000000000 --- a/tests/flytekit/unit/common_tests/types/impl/test_blobs.py +++ /dev/null @@ -1,332 +0,0 @@ -import os - -import pytest - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types.impl import blobs -from flytekit.common.utils import AutoDeletingTempDir -from flytekit.models.core import types as _core_types -from flytekit.sdk import test_utils - - -def test_blob(): - b = blobs.Blob("/tmp/fake") - assert b.remote_location == "/tmp/fake" - assert b.local_path is None - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - - -def test_blob_from_python_std(): - with test_utils.LocalTestFileSystem() as t: - with AutoDeletingTempDir("test") as wd: - tmp_name = wd.get_named_tempfile("from_python_std") - with open(tmp_name, "wb") as w: - w.write("hello hello".encode("utf-8")) - b = blobs.Blob.from_python_std(tmp_name) - assert b.mode == "wb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - assert b.remote_location.startswith(t.name) - assert b.local_path == tmp_name - with open(b.remote_location, "rb") as r: - assert r.read() == "hello hello".encode("utf-8") - - b = blobs.Blob("/tmp/fake") - b2 = blobs.Blob.from_python_std(b) - assert b == b2 - - with pytest.raises(_user_exceptions.FlyteTypeException): - blobs.Blob.from_python_std(3) - - -def test_blob_create_at(): - with test_utils.LocalTestFileSystem() as t: - with AutoDeletingTempDir("test") as wd: - tmp_name = wd.get_named_tempfile("tmp") - b = blobs.Blob.create_at_known_location(tmp_name) - assert b.local_path is None - assert b.remote_location == tmp_name - assert b.mode == "wb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - with b as w: - w.write("hello hello".encode("utf-8")) - - assert b.local_path.startswith(t.name) - with open(tmp_name, "rb") as r: - assert r.read() == "hello hello".encode("utf-8") - - -def test_blob_fetch_managed(): - with AutoDeletingTempDir("test") as wd: - with test_utils.LocalTestFileSystem() as t: - tmp_name = wd.get_named_tempfile("tmp") - with open(tmp_name, "wb") as w: - w.write("hello".encode("utf-8")) - - b = 
blobs.Blob.fetch(tmp_name) - assert b.local_path.startswith(t.name) - assert b.remote_location == tmp_name - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - with b as r: - assert r.read() == "hello".encode("utf-8") - - with pytest.raises(_user_exceptions.FlyteAssertion): - blobs.Blob.fetch(tmp_name, local_path=b.local_path) - - with open(tmp_name, "wb") as w: - w.write("bye".encode("utf-8")) - - b2 = blobs.Blob.fetch(tmp_name, local_path=b.local_path, overwrite=True) - with b2 as r: - assert r.read() == "bye".encode("utf-8") - - with pytest.raises(_user_exceptions.FlyteAssertion): - blobs.Blob.fetch(tmp_name) - - -def test_blob_fetch_unmanaged(): - with AutoDeletingTempDir("test") as wd: - with AutoDeletingTempDir("test2") as t: - tmp_name = wd.get_named_tempfile("source") - tmp_sink = t.get_named_tempfile("sink") - with open(tmp_name, "wb") as w: - w.write("hello".encode("utf-8")) - - b = blobs.Blob.fetch(tmp_name, local_path=tmp_sink) - assert b.local_path == tmp_sink - assert b.remote_location == tmp_name - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - with b as r: - assert r.read() == "hello".encode("utf-8") - - with pytest.raises(_user_exceptions.FlyteAssertion): - blobs.Blob.fetch(tmp_name, local_path=tmp_sink) - - with open(tmp_name, "wb") as w: - w.write("bye".encode("utf-8")) - - b2 = blobs.Blob.fetch(tmp_name, local_path=tmp_sink, overwrite=True) - with b2 as r: - assert r.read() == "bye".encode("utf-8") - - -def test_blob_double_enter(): - with test_utils.LocalTestFileSystem(): - with AutoDeletingTempDir("test") as wd: - b = blobs.Blob(wd.get_named_tempfile("sink"), mode="wb") - with b: - with pytest.raises(_user_exceptions.FlyteAssertion): - with b: - pass - - -def test_blob_download_managed(): - with AutoDeletingTempDir("test") as wd: - with test_utils.LocalTestFileSystem() as t: - tmp_name = wd.get_named_tempfile("tmp") - with open(tmp_name, "wb") as w: - w.write("hello".encode("utf-8")) - - b = blobs.Blob(tmp_name) - b.download() - assert b.local_path.startswith(t.name) - assert b.remote_location == tmp_name - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - with b as r: - assert r.read() == "hello".encode("utf-8") - - b2 = blobs.Blob(tmp_name) - with pytest.raises(_user_exceptions.FlyteAssertion): - b2.download(b.local_path) - - with open(tmp_name, "wb") as w: - w.write("bye".encode("utf-8")) - - b2 = blobs.Blob(tmp_name) - b2.download(local_path=b.local_path, overwrite=True) - with b2 as r: - assert r.read() == "bye".encode("utf-8") - - b = blobs.Blob(tmp_name) - with pytest.raises(_user_exceptions.FlyteAssertion): - b.download() - - -def test_blob_download_unmanaged(): - with AutoDeletingTempDir("test") as wd: - with AutoDeletingTempDir("test2") as t: - tmp_name = wd.get_named_tempfile("source") - tmp_sink = t.get_named_tempfile("sink") - with open(tmp_name, "wb") as w: - w.write("hello".encode("utf-8")) - - b = blobs.Blob(tmp_name) - b.download(tmp_sink) - assert b.local_path == tmp_sink - assert b.remote_location == tmp_name - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - with b as r: - assert r.read() == "hello".encode("utf-8") - - b = blobs.Blob(tmp_name) - 
with pytest.raises(_user_exceptions.FlyteAssertion):
-                b.download(tmp_sink)
-
-            with open(tmp_name, "wb") as w:
-                w.write("bye".encode("utf-8"))
-
-            b2 = blobs.Blob(tmp_name)
-            b2.download(tmp_sink, overwrite=True)
-            with b2 as r:
-                assert r.read() == "bye".encode("utf-8")
-
-
-def test_multipart_blob():
-    b = blobs.MultiPartBlob("/tmp/fake", mode="w", format="csv")
-    assert b.remote_location == "/tmp/fake/"
-    assert b.local_path is None
-    assert b.mode == "w"
-    assert b.metadata.type.format == "csv"
-    assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
-
-
-def _generate_multipart_blob_data(tmp_dir):
-    n = tmp_dir.get_named_tempfile("0")
-    with open(n, "wb") as w:
-        w.write("part0".encode("utf-8"))
-    n = tmp_dir.get_named_tempfile("1")
-    with open(n, "wb") as w:
-        w.write("part1".encode("utf-8"))
-    n = tmp_dir.get_named_tempfile("2")
-    with open(n, "wb") as w:
-        w.write("part2".encode("utf-8"))
-
-
-def test_multipart_blob_from_python_std():
-    with test_utils.LocalTestFileSystem() as t:
-        with AutoDeletingTempDir("test") as wd:
-            _generate_multipart_blob_data(wd)
-            b = blobs.MultiPartBlob.from_python_std(wd.name)
-            assert b.mode == "wb"
-            assert b.metadata.type.format == ""
-            assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
-            assert b.remote_location.startswith(t.name)
-            assert b.local_path == wd.name
-            with open(os.path.join(b.remote_location, "0"), "rb") as r:
-                assert r.read() == "part0".encode("utf-8")
-            with open(os.path.join(b.remote_location, "1"), "rb") as r:
-                assert r.read() == "part1".encode("utf-8")
-            with open(os.path.join(b.remote_location, "2"), "rb") as r:
-                assert r.read() == "part2".encode("utf-8")
-
-    b = blobs.MultiPartBlob("/tmp/fake/")
-    b2 = blobs.MultiPartBlob.from_python_std(b)
-    assert b == b2
-
-    with pytest.raises(_user_exceptions.FlyteTypeException):
-        blobs.MultiPartBlob.from_python_std(3)
-
-
-def test_multipart_blob_create_at():
-    with test_utils.LocalTestFileSystem():
-        with AutoDeletingTempDir("test") as wd:
-            b = blobs.MultiPartBlob.create_at_known_location(wd.name)
-            assert b.local_path is None
-            assert b.remote_location == wd.name + "/"
-            assert b.mode == "wb"
-            assert b.metadata.type.format == ""
-            assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
-            with b.create_part("0") as w:
-                w.write("part0".encode("utf-8"))
-            with b.create_part("1") as w:
-                w.write("part1".encode("utf-8"))
-            with b.create_part("2") as w:
-                w.write("part2".encode("utf-8"))
-
-            with open(os.path.join(wd.name, "0"), "rb") as r:
-                assert r.read() == "part0".encode("utf-8")
-            with open(os.path.join(wd.name, "1"), "rb") as r:
-                assert r.read() == "part1".encode("utf-8")
-            with open(os.path.join(wd.name, "2"), "rb") as r:
-                assert r.read() == "part2".encode("utf-8")
-
-
-def test_multipart_blob_fetch_managed():
-    with AutoDeletingTempDir("test") as wd:
-        with test_utils.LocalTestFileSystem() as t:
-            _generate_multipart_blob_data(wd)
-
-            b = blobs.MultiPartBlob.fetch(wd.name)
-            assert b.local_path.startswith(t.name)
-            assert b.remote_location == wd.name + "/"
-            assert b.mode == "rb"
-            assert b.metadata.type.format == ""
-            assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
-            with b as r:
-                assert r[0].read() == "part0".encode("utf-8")
-                assert r[1].read() == "part1".encode("utf-8")
-                assert r[2].read() == "part2".encode("utf-8")
-
-            with pytest.raises(_user_exceptions.FlyteAssertion):
-                blobs.MultiPartBlob.fetch(wd.name, 
local_path=b.local_path) - - with open(os.path.join(wd.name, "0"), "wb") as w: - w.write("bye".encode("utf-8")) - - b2 = blobs.MultiPartBlob.fetch(wd.name, local_path=b.local_path, overwrite=True) - with b2 as r: - assert r[0].read() == "bye".encode("utf-8") - assert r[1].read() == "part1".encode("utf-8") - assert r[2].read() == "part2".encode("utf-8") - - with pytest.raises(_user_exceptions.FlyteAssertion): - blobs.MultiPartBlob.fetch(wd.name) - - -def test_multipart_blob_fetch_unmanaged(): - with AutoDeletingTempDir("test") as wd: - with AutoDeletingTempDir("test2") as t: - _generate_multipart_blob_data(wd) - tmp_sink = t.get_named_tempfile("sink") - - b = blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink) - assert b.local_path == tmp_sink - assert b.remote_location == wd.name + "/" - assert b.mode == "rb" - assert b.metadata.type.format == "" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART - with b as r: - assert r[0].read() == "part0".encode("utf-8") - assert r[1].read() == "part1".encode("utf-8") - assert r[2].read() == "part2".encode("utf-8") - - with pytest.raises(_user_exceptions.FlyteAssertion): - blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink) - - with open(os.path.join(wd.name, "0"), "wb") as w: - w.write("bye".encode("utf-8")) - - b2 = blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink, overwrite=True) - with b2 as r: - assert r[0].read() == "bye".encode("utf-8") - assert r[1].read() == "part1".encode("utf-8") - assert r[2].read() == "part2".encode("utf-8") - - -def test_multipart_blob_no_enter_on_write(): - with test_utils.LocalTestFileSystem(): - b = blobs.MultiPartBlob.create_at_any_location() - with pytest.raises(_user_exceptions.FlyteAssertion): - with b: - pass diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py deleted file mode 100644 index de5bba4a46..0000000000 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ /dev/null @@ -1,553 +0,0 @@ -import collections as _collections -import datetime as _datetime -import os as _os -import uuid as _uuid - -import pandas as _pd -import pytest as _pytest -import six.moves as _six_moves - -from flytekit.common import utils as _utils -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import blobs as _blobs -from flytekit.common.types import primitives as _primitives -from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models -from flytekit.sdk import test_utils as _test_utils - - -def test_schema_type(): - _schema_impl.SchemaType() - _schema_impl.SchemaType([]) - _schema_impl.SchemaType( - [ - ("a", _primitives.Integer), - ("b", _primitives.String), - ("c", _primitives.Float), - ("d", _primitives.Boolean), - ("e", _primitives.Datetime), - ] - ) - - with _pytest.raises(ValueError): - _schema_impl.SchemaType({"a": _primitives.Integer}) - - with _pytest.raises(TypeError): - _schema_impl.SchemaType([("a", _blobs.Blob)]) - - with _pytest.raises(ValueError): - _schema_impl.SchemaType([("a", _primitives.Integer, 1)]) - - _schema_impl.SchemaType([("1", _primitives.Integer)]) - with _pytest.raises(TypeError): - _schema_impl.SchemaType([(1, _primitives.Integer)]) - - with _pytest.raises(TypeError): - _schema_impl.SchemaType([("1", [_primitives.Integer])]) - - -value_type_tuples = [ - ("abra", _primitives.Integer, [1, 2, 3, 4, 5]), - 
("CADABRA", _primitives.Float, [1.0, 2.0, 3.0, 4.0, 5.0]), - ("HoCuS", _primitives.String, ["A", "B", "C", "D", "E"]), - ("Pocus", _primitives.Boolean, [True, False, True, False]), - ( - "locusts", - _primitives.Datetime, - [ - _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) - - _datetime.timedelta(days=i) - for i in _six_moves.range(5) - ], - ), -] - - -@_pytest.mark.parametrize("value_type_pair", value_type_tuples) -def test_simple_read_and_write_with_different_types(value_type_pair): - column_name, flyte_type, values = value_type_pair - values = [tuple([value]) for value in values] - schema_type = _schema_impl.SchemaType(columns=[(column_name, flyte_type)]) - - with _test_utils.LocalTestFileSystem() as sandbox: - with _utils.AutoDeletingTempDir("test") as t: - a = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) - assert a.local_path is None - with a as writer: - for _ in _six_moves.range(5): - writer.write(_pd.DataFrame.from_records(values, columns=[column_name])) - assert a.local_path.startswith(sandbox.name) - assert a.local_path is None - - b = _schema_impl.Schema.create_at_known_location(t.name, mode="rb", schema_type=schema_type) - assert b.local_path is None - with b as reader: - for df in reader.iter_chunks(): - for check, actual in _six_moves.zip(values, df[column_name].tolist()): - assert check[0] == actual - assert reader.read() is None - reader.seek(0) - df = reader.read(concat=True) - for iter_count, actual in enumerate(df[column_name].tolist()): - assert values[iter_count % len(values)][0] == actual - assert b.local_path.startswith(sandbox.name) - assert b.local_path is None - - -def test_datetime_coercion_explicitly(): - """ - Sanity check that we're using a version of pyarrow that allows us to - truncate timestamps - """ - dt = _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) - values = [(dt,)] - df = _pd.DataFrame.from_records(values, columns=["testname"]) - assert df["testname"][0] == dt - - with _utils.AutoDeletingTempDir("test") as tmpdir: - tmpfile = tmpdir.get_named_tempfile("repro.parquet") - df.to_parquet(tmpfile, coerce_timestamps="ms", allow_truncated_timestamps=True) - df2 = _pd.read_parquet(tmpfile) - - dt2 = _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1) - assert df2["testname"][0] == dt2 - - -def test_datetime_coercion(): - values = [ - tuple( - [ - _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) - - _datetime.timedelta(days=x) - ] - ) - for x in _six_moves.range(5) - ] - schema_type = _schema_impl.SchemaType(columns=[("testname", _primitives.Datetime)]) - - with _test_utils.LocalTestFileSystem(): - with _utils.AutoDeletingTempDir("test") as t: - a = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) - with a as writer: - for _ in _six_moves.range(5): - # us to ms coercion segfaults unless we explicitly allow truncation. 
- writer.write( - _pd.DataFrame.from_records(values, columns=["testname"]), - coerce_timestamps="ms", - allow_truncated_timestamps=True, - ) - - # TODO: Uncomment when segfault bug is resolved - # with _pytest.raises(Exception): - # writer.write( - # _pd.DataFrame.from_records(values, columns=['testname']), - # coerce_timestamps='ms') - - b = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) - with b as writer: - for _ in _six_moves.range(5): - writer.write(_pd.DataFrame.from_records(values, columns=["testname"])) - - -@_pytest.mark.parametrize("value_type_pair", value_type_tuples) -def test_fetch(value_type_pair): - column_name, flyte_type, values = value_type_pair - values = [tuple([value]) for value in values] - schema_type = _schema_impl.SchemaType(columns=[(column_name, flyte_type)]) - - with _utils.AutoDeletingTempDir("test") as tmpdir: - for i in _six_moves.range(3): - _pd.DataFrame.from_records(values, columns=[column_name]).to_parquet( - tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps="us" - ) - - with _utils.AutoDeletingTempDir("test2") as local_dir: - schema_obj = _schema_impl.Schema.fetch( - tmpdir.name, - local_path=local_dir.get_named_tempfile("schema_test"), - schema_type=schema_type, - ) - with schema_obj as reader: - for df in reader.iter_chunks(): - for check, actual in _six_moves.zip(values, df[column_name].tolist()): - assert check[0] == actual - assert reader.read() is None - reader.seek(0) - df = reader.read(concat=True) - for iter_count, actual in enumerate(df[column_name].tolist()): - assert values[iter_count % len(values)][0] == actual - - -@_pytest.mark.parametrize("value_type_pair", value_type_tuples) -def test_download(value_type_pair): - column_name, flyte_type, values = value_type_pair - values = [tuple([value]) for value in values] - schema_type = _schema_impl.SchemaType(columns=[(column_name, flyte_type)]) - - with _utils.AutoDeletingTempDir("test") as tmpdir: - for i in _six_moves.range(3): - _pd.DataFrame.from_records(values, columns=[column_name]).to_parquet( - tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps="us" - ) - - with _utils.AutoDeletingTempDir("test2") as local_dir: - schema_obj = _schema_impl.Schema(tmpdir.name, schema_type=schema_type) - schema_obj.download(local_dir.get_named_tempfile(_uuid.uuid4().hex)) - with schema_obj as reader: - for df in reader.iter_chunks(): - for check, actual in _six_moves.zip(values, df[column_name].tolist()): - assert check[0] == actual - assert reader.read() is None - reader.seek(0) - df = reader.read(concat=True) - for iter_count, actual in enumerate(df[column_name].tolist()): - assert values[iter_count % len(values)][0] == actual - - with _pytest.raises(Exception): - schema_obj = _schema_impl.Schema(tmpdir.name, schema_type=schema_type) - schema_obj.download() - - with _test_utils.LocalTestFileSystem(): - schema_obj = _schema_impl.Schema(tmpdir.name, schema_type=schema_type) - schema_obj.download() - with schema_obj as reader: - for df in reader.iter_chunks(): - for check, actual in _six_moves.zip(values, df[column_name].tolist()): - assert check[0] == actual - assert reader.read() is None - reader.seek(0) - df = reader.read(concat=True) - for iter_count, actual in enumerate(df[column_name].tolist()): - assert values[iter_count % len(values)][0] == actual - - -def test_hive_queries(monkeypatch): - def return_deterministic_uuid(): - class FakeUUID4(object): - def __init__(self): - self.hex = "test_uuid" - - class Uuid(object): - def uuid4(self): - 
return FakeUUID4() - - return Uuid() - - monkeypatch.setattr(_schema_impl, "_uuid", return_deterministic_uuid()) - - all_types = _schema_impl.SchemaType( - [ - ("a", _primitives.Integer), - ("b", _primitives.String), - ("c", _primitives.Float), - ("d", _primitives.Boolean), - ("e", _primitives.Datetime), - ] - ) - - with _test_utils.LocalTestFileSystem(): - df, query = _schema_impl.Schema.create_from_hive_query( - "SELECT a, b, c, d, e FROM some_place WHERE i = 0", - stage_query="CREATE TEMPORARY TABLE some_place AS SELECT * FROM some_place_original", - known_location="s3://my_fixed_path/", - schema_type=all_types, - ) - - full_query = """ - CREATE TEMPORARY TABLE some_place AS SELECT * FROM some_place_original; - CREATE TEMPORARY TABLE test_uuid_tmp AS SELECT a, b, c, d, e FROM some_place WHERE i = 0; - CREATE EXTERNAL TABLE test_uuid LIKE test_uuid_tmp STORED AS PARQUET; - ALTER TABLE test_uuid SET LOCATION 's3://my_fixed_path/'; - INSERT OVERWRITE TABLE test_uuid - SELECT - a as a, - b as b, - CAST(c as double) c, - d as d, - e as e - FROM test_uuid_tmp; - DROP TABLE test_uuid; - """ - full_query = " ".join(full_query.split()) - query = " ".join(query.split()) - assert query == full_query - - # Test adding partition - full_query = """ - ALTER TABLE some_table ADD IF NOT EXISTS PARTITION ( - region = 'SEA', - ds = '2017-01-01' - ) LOCATION 's3://my_fixed_path/'; - ALTER TABLE some_table PARTITION ( - region = 'SEA', - ds = '2017-01-01' - ) SET LOCATION 's3://my_fixed_path/'; - """ - query = df.get_write_partition_to_hive_table_query( - "some_table", - partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), - ) - full_query = " ".join(full_query.split()) - query = " ".join(query.split()) - assert query == full_query - - -def test_partial_column_read(): - with _test_utils.LocalTestFileSystem(): - a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) - ) - with a as writer: - writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) - - b = _schema_impl.Schema.fetch( - a.uri, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - with b as reader: - df = reader.read(columns=["b"]) - assert df.columns.values == ["b"] - assert df["b"].tolist() == [5, 6, 7, 8] - - -def test_casting(): - pass - - -def test_from_python_std(): - with _test_utils.LocalTestFileSystem(): - - def single_dataframe(): - df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) - s = _schema_impl.Schema.from_python_std( - t_value=df1, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - assert s is not None - n = _schema_impl.Schema.fetch( - s.uri, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - with n as reader: - df2 = reader.read() - assert df2.columns.values.all() == df1.columns.values.all() - assert df2["b"].tolist() == df1["b"].tolist() - - def list_of_dataframes(): - df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) - df2 = _pd.DataFrame.from_dict({"a": [9, 10, 11, 12], "b": [13, 14, 15, 16]}) - s = _schema_impl.Schema.from_python_std( - t_value=[df1, df2], - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - assert s is not None - n = _schema_impl.Schema.fetch( - s.uri, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", 
_primitives.Integer)]), - ) - with n as reader: - actual = [] - for df in reader.iter_chunks(): - assert df.columns.values.all() == df1.columns.values.all() - actual.extend(df["b"].tolist()) - b_val = df1["b"].tolist() - b_val.extend(df2["b"].tolist()) - assert actual == b_val - - def mixed_list(): - df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) - df2 = [1, 2, 3] - with _pytest.raises(_user_exceptions.FlyteTypeException): - _schema_impl.Schema.from_python_std( - t_value=[df1, df2], - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - - def empty_list(): - s = _schema_impl.Schema.from_python_std( - t_value=[], - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - assert s is not None - n = _schema_impl.Schema.fetch( - s.uri, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), - ) - with n as reader: - df = reader.read() - assert df is None - - single_dataframe() - mixed_list() - empty_list() - list_of_dataframes() - - -def test_promote_from_model_schema_type(): - m = _type_models.SchemaType( - [ - _type_models.SchemaType.SchemaColumn("a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - _type_models.SchemaType.SchemaColumn("b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _type_models.SchemaType.SchemaColumn("c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _type_models.SchemaType.SchemaColumn("e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ] - ) - s = _schema_impl.SchemaType.promote_from_model(m) - assert s.columns == m.columns - assert s.sdk_columns["a"].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() - assert s.sdk_columns["b"].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() - assert s.sdk_columns["c"].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() - assert s.sdk_columns["d"].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() - assert s.sdk_columns["e"].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() - assert s.sdk_columns["f"].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() - assert s == m - - -def test_promote_from_model_schema(): - m = _literal_models.Schema( - "s3://some/place/", - _type_models.SchemaType( - [ - _type_models.SchemaType.SchemaColumn( - "a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN - ), - _type_models.SchemaType.SchemaColumn( - "b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME - ), - _type_models.SchemaType.SchemaColumn( - "c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION - ), - _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _type_models.SchemaType.SchemaColumn( - "e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER - ), - _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ] - ), - ) - - s = _schema_impl.Schema.promote_from_model(m) - assert s.uri == "s3://some/place/" - assert s.type.sdk_columns["a"].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() - assert 
s.type.sdk_columns["b"].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() - assert s.type.sdk_columns["c"].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() - assert s.type.sdk_columns["d"].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() - assert s.type.sdk_columns["e"].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() - assert s.type.sdk_columns["f"].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() - assert s == m - - -def test_create_at_known_location(): - with _test_utils.LocalTestFileSystem(): - with _utils.AutoDeletingTempDir("test") as wd: - b = _schema_impl.Schema.create_at_known_location(wd.name, schema_type=_schema_impl.SchemaType()) - assert b.local_path is None - assert b.remote_location == wd.name + "/" - assert b.mode == "wb" - - with b as w: - w.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) - - df = _pd.read_parquet(_os.path.join(wd.name, "000000")) - assert list(df["a"]) == [1, 2, 3, 4] - assert list(df["b"]) == [5, 6, 7, 8] - - -def test_generic_schema_read(): - with _test_utils.LocalTestFileSystem(): - a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) - ) - with a as writer: - writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) - - b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) - with b as reader: - df = reader.read() - assert df.columns.values.tolist() == ["a", "b"] - assert df["a"].tolist() == [1, 2, 3, 4] - assert df["b"].tolist() == [5, 6, 7, 8] - - -def test_extra_schema_read(): - with _test_utils.LocalTestFileSystem(): - a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) - ) - with a as writer: - writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) - - b = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), - ) - with b as reader: - df = reader.read(concat=True, truncate_extra_columns=False) - assert df.columns.values.tolist() == ["a", "b"] - assert df["a"].tolist() == [1, 2, 3, 4] - assert df["b"].tolist() == [5, 6, 7, 8] - - with b as reader: - df = reader.read(concat=True) - assert df.columns.values.tolist() == ["a"] - assert df["a"].tolist() == [1, 2, 3, 4] - - -def test_normal_schema_read_with_fastparquet(): - with _test_utils.LocalTestFileSystem(): - a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Boolean)]) - ) - with a as writer: - writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [False, True, True, False]})) - - import os as _os - - original_engine = _os.getenv("PARQUET_ENGINE") - _os.environ["PARQUET_ENGINE"] = "fastparquet" - - b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) - - with b as reader: - df = reader.read() - assert df["a"].tolist() == [1, 2, 3, 4] - assert _pd.api.types.is_bool_dtype(df.dtypes["b"]) - assert df["b"].tolist() == [False, True, True, False] - - if original_engine is None: - del _os.environ["PARQUET_ENGINE"] - else: - _os.environ["PARQUET_ENGINE"] = original_engine - - -def test_schema_read_consistency_between_two_engines(): - with _test_utils.LocalTestFileSystem(): - a = _schema_impl.Schema.create_at_any_location( - 
schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Boolean)]) - ) - with a as writer: - writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [True, True, True, False]})) - - import os as _os - - original_engine = _os.getenv("PARQUET_ENGINE") - _os.environ["PARQUET_ENGINE"] = "fastparquet" - - b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) - - with b as b_reader: - b_df = b_reader.read() - _os.environ["PARQUET_ENGINE"] = "pyarrow" - - c = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) - with c as c_reader: - c_df = c_reader.read() - assert b_df.equals(c_df) - - if original_engine is None: - del _os.environ["PARQUET_ENGINE"] - else: - _os.environ["PARQUET_ENGINE"] = original_engine diff --git a/tests/flytekit/unit/common_tests/types/test_blobs.py b/tests/flytekit/unit/common_tests/types/test_blobs.py deleted file mode 100644 index 3057b11c6a..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_blobs.py +++ /dev/null @@ -1,141 +0,0 @@ -from flytekit.common.types import blobs -from flytekit.common.types.impl import blobs as blob_impl -from flytekit.models import literals as _literal_models -from flytekit.models.core import types as _core_types -from flytekit.sdk import test_utils - - -def test_blob_instantiator(): - b = blobs.BlobInstantiator.create_at_known_location("abc") - assert isinstance(b, blob_impl.Blob) - assert b.remote_location == "abc" - assert b.mode == "wb" - assert b.metadata.type.format == "" - - -def test_blob(): - with test_utils.LocalTestFileSystem() as t: - b = blobs.Blob() - assert isinstance(b, blob_impl.Blob) - assert b.remote_location.startswith(t.name) - assert b.mode == "wb" - assert b.metadata.type.format == "" - - b2 = blobs.Blob(b) - assert isinstance(b2, blobs.Blob) - assert b2.scalar.blob.uri == b.remote_location - assert b2.scalar.blob.metadata == b.metadata - - b3 = blobs.Blob.from_string("/a/b/c") - assert isinstance(b3, blobs.Blob) - assert b3.scalar.blob.uri == "/a/b/c" - assert b3.scalar.blob.metadata.type.format == "" - - -def test_blob_promote_from_model(): - m = _literal_models.Literal( - scalar=_literal_models.Scalar( - blob=_literal_models.Blob( - _literal_models.BlobMetadata( - _core_types.BlobType( - format="f", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) - ), - "some/path", - ) - ) - ) - b = blobs.Blob.promote_from_model(m) - assert b.value.blob.uri == "some/path" - assert b.value.blob.metadata.type.format == "f" - assert b.value.blob.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - - -def test_blob_to_python_std(): - impl = blob_impl.Blob("some/path", format="something") - b = blobs.Blob(impl).to_python_std() - assert b.metadata.type.format == "something" - assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE - assert b.uri == "some/path" - - -def test_csv_instantiator(): - b = blobs.CsvInstantiator.create_at_known_location("abc") - assert isinstance(b, blob_impl.Blob) - assert b.remote_location == "abc" - assert b.mode == "w" - assert b.metadata.type.format == "csv" - - -def test_csv(): - with test_utils.LocalTestFileSystem() as t: - b = blobs.CSV() - assert isinstance(b, blob_impl.Blob) - assert b.remote_location.startswith(t.name) - assert b.mode == "w" - assert b.metadata.type.format == "csv" - - b2 = blobs.CSV(b) - assert isinstance(b2, blobs.Blob) - assert b2.scalar.blob.uri == b.remote_location - assert 
b2.scalar.blob.metadata == b.metadata - - b3 = blobs.CSV.from_string("/a/b/c") - assert isinstance(b3, blobs.Blob) - assert b3.scalar.blob.uri == "/a/b/c" - assert b3.scalar.blob.metadata.type.format == "csv" - - -def test_multipartblob_instantiator(): - b = blobs.MultiPartBlob.create_at_known_location("abc") - assert isinstance(b, blob_impl.MultiPartBlob) - assert b.remote_location == "abc" + "/" - assert b.mode == "wb" - assert b.metadata.type.format == "" - - -def test_multipartblob(): - with test_utils.LocalTestFileSystem() as t: - b = blobs.MultiPartBlob() - assert isinstance(b, blob_impl.MultiPartBlob) - assert b.remote_location.startswith(t.name) - assert b.mode == "wb" - assert b.metadata.type.format == "" - - b2 = blobs.MultiPartBlob(b) - assert isinstance(b2, blobs.MultiPartBlob) - assert b2.scalar.blob.uri == b.remote_location - assert b2.scalar.blob.metadata == b.metadata - - b3 = blobs.MultiPartBlob.from_string("/a/b/c") - assert isinstance(b3, blobs.MultiPartBlob) - assert b3.scalar.blob.uri == "/a/b/c/" - assert b3.scalar.blob.metadata.type.format == "" - - -def test_multipartcsv_instantiator(): - b = blobs.MultiPartCsvInstantiator.create_at_known_location("abc") - assert isinstance(b, blob_impl.MultiPartBlob) - assert b.remote_location == "abc" + "/" - assert b.mode == "w" - assert b.metadata.type.format == "csv" - - -def test_multipartcsv(): - with test_utils.LocalTestFileSystem() as t: - b = blobs.MultiPartCSV() - assert isinstance(b, blob_impl.MultiPartBlob) - assert b.remote_location.startswith(t.name) - assert b.mode == "w" - assert b.metadata.type.format == "csv" - - b2 = blobs.MultiPartCSV(b) - assert isinstance(b2, blobs.MultiPartCSV) - assert b2.scalar.blob.uri == b.remote_location - assert b2.scalar.blob.metadata == b.metadata - - b3 = blobs.MultiPartCSV.from_string("/a/b/c") - assert isinstance(b3, blobs.MultiPartCSV) - assert b3.scalar.blob.uri == "/a/b/c/" - assert b3.scalar.blob.metadata.type.format == "csv" diff --git a/tests/flytekit/unit/common_tests/types/test_containers.py b/tests/flytekit/unit/common_tests/types/test_containers.py deleted file mode 100644 index dcb8730b08..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_containers.py +++ /dev/null @@ -1,173 +0,0 @@ -import pytest -from six.moves import range as _range - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import containers, primitives -from flytekit.models import literals -from flytekit.models import types as literal_types - - -def test_list(): - list_type = containers.List(primitives.Integer) - assert list_type.to_flyte_literal_type().simple is None - assert list_type.to_flyte_literal_type().map_value_type is None - assert list_type.to_flyte_literal_type().schema is None - assert list_type.to_flyte_literal_type().collection_type.simple == literal_types.SimpleType.INTEGER - - list_value = list_type.from_python_std([1, 2, 3, 4]) - assert list_value.to_python_std() == [1, 2, 3, 4] - assert list_type.from_flyte_idl(list_value.to_flyte_idl()) == list_value - - assert list_value.collection.literals[0].scalar.primitive.integer == 1 - assert list_value.collection.literals[1].scalar.primitive.integer == 2 - assert list_value.collection.literals[2].scalar.primitive.integer == 3 - assert list_value.collection.literals[3].scalar.primitive.integer == 4 - - obj2 = list_type.from_string("[1, 2, 3,4]") - assert obj2 == list_value - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std(["a", "b", "c", "d"]) - - with 
pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std([1, 2, 3, "abc"]) - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std(1) - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std([[1]]) - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('["fdsa"]') - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string("[1, 2, 3, []]") - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string("'[\"not list json\"]'") - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('["unclosed","list"') - - -def test_string_list(): - list_type = containers.List(primitives.String) - obj = list_type.from_string('["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]') - assert len(obj.collection.literals) == 5 - assert obj.to_python_std() == ["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""] - - # Test that two classes of the same type are comparable - list_type_two = containers.List(primitives.String) - obj2 = list_type_two.from_string('["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]') - assert obj == obj2 - - -def test_empty_parsing(): - list_type = containers.List(primitives.String) - obj = list_type.from_string("[]") - assert len(obj) == 0 - - # The String primitive type does not allow lists or maps to be converted - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('["fdjs", []]') - - with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('["fdjs", {}]') - - -def test_nested_list(): - list_type = containers.List(containers.List(primitives.Integer)) - - assert list_type.to_flyte_literal_type().simple is None - assert list_type.to_flyte_literal_type().map_value_type is None - assert list_type.to_flyte_literal_type().schema is None - assert list_type.to_flyte_literal_type().collection_type.simple is None - assert list_type.to_flyte_literal_type().collection_type.map_value_type is None - assert list_type.to_flyte_literal_type().collection_type.schema is None - assert list_type.to_flyte_literal_type().collection_type.collection_type.simple == literal_types.SimpleType.INTEGER - - gt = [[1, 2, 3], [4, 5, 6], []] - list_value = list_type.from_python_std(gt) - assert list_value.to_python_std() == gt - assert list_type.from_flyte_idl(list_value.to_flyte_idl()) == list_value - - assert list_value.collection.literals[0].collection.literals[0].scalar.primitive.integer == 1 - assert list_value.collection.literals[0].collection.literals[1].scalar.primitive.integer == 2 - assert list_value.collection.literals[0].collection.literals[2].scalar.primitive.integer == 3 - - assert list_value.collection.literals[1].collection.literals[0].scalar.primitive.integer == 4 - assert list_value.collection.literals[1].collection.literals[1].scalar.primitive.integer == 5 - assert list_value.collection.literals[1].collection.literals[2].scalar.primitive.integer == 6 - - assert len(list_value.collection.literals[2].collection.literals) == 0 - - obj = list_type.from_string("[[1, 2, 3], [4, 5, 6]]") - assert len(obj) == 2 - assert len(obj.collection.literals[0]) == 3 - - -def test_reprs(): - list_type = containers.List(primitives.Integer) - obj = list_type.from_python_std(list(_range(3))) - assert obj.short_string() == "List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)])" - assert ( - obj.verbose_string() == "List<Integer>(\n" - "\tlen=3,\n" - "\t[\n" - "\t\tInteger(0),\n" - "\t\tInteger(1),\n" 
- "\t\tInteger(2)\n" - "\t]\n" - ")" - ) - - nested_list_type = containers.List(containers.List(primitives.Integer)) - nested_obj = nested_list_type.from_python_std([list(_range(3)), list(_range(3))]) - - assert ( - nested_obj.short_string() - == "List>(len=2, [List(len=3, [Integer(0), Integer(1), Integer(2)]), " - "List(len=3, [Integer(0), Integer(1), Integer(2)])])" - ) - assert ( - nested_obj.verbose_string() == "List>(\n" - "\tlen=2,\n" - "\t[\n" - "\t\tList(\n" - "\t\t\tlen=3,\n" - "\t\t\t[\n" - "\t\t\t\tInteger(0),\n" - "\t\t\t\tInteger(1),\n" - "\t\t\t\tInteger(2)\n" - "\t\t\t]\n" - "\t\t),\n" - "\t\tList(\n" - "\t\t\tlen=3,\n" - "\t\t\t[\n" - "\t\t\t\tInteger(0),\n" - "\t\t\t\tInteger(1),\n" - "\t\t\t\tInteger(2)\n" - "\t\t\t]\n" - "\t\t)\n" - "\t]\n" - ")" - ) - - -def test_model_promotion(): - list_type = containers.List(primitives.Integer) - list_model = literals.Literal( - collection=literals.LiteralCollection( - literals=[ - literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=0))), - literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), - literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2))), - ] - ) - ) - list_obj = list_type.promote_from_model(list_model) - assert len(list_obj.collection.literals) == 3 - assert isinstance(list_obj.collection.literals[0], primitives.Integer) - assert list_obj == list_type.from_python_std([0, 1, 2]) - assert list_obj == list_type([primitives.Integer(0), primitives.Integer(1), primitives.Integer(2)]) diff --git a/tests/flytekit/unit/common_tests/types/test_helpers.py b/tests/flytekit/unit/common_tests/types/test_helpers.py deleted file mode 100644 index dd8b45af23..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literals -from flytekit.models import types as _model_types -from flytekit.sdk import types as _sdk_types - - -def test_python_std_to_sdk_type(): - o = _type_helpers.python_std_to_sdk_type(_sdk_types.Types.Integer) - assert o.to_flyte_literal_type().simple == _model_types.SimpleType.INTEGER - - o = _type_helpers.python_std_to_sdk_type([_sdk_types.Types.Boolean]) - assert o.to_flyte_literal_type().collection_type.simple == _model_types.SimpleType.BOOLEAN - - -def test_get_sdk_type_from_literal_type(): - o = _type_helpers.get_sdk_type_from_literal_type(_model_types.LiteralType(simple=_model_types.SimpleType.FLOAT)) - assert o == _sdk_types.Types.Float - - -def test_infer_sdk_type_from_literal(): - o = _type_helpers.infer_sdk_type_from_literal( - _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(string_value="abc"))) - ) - assert o == _sdk_types.Types.String - - o = _type_helpers.infer_sdk_type_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())) - ) - assert o is _base_sdk_types.Void - - -def test_get_sdk_value_from_literal(): - o = _type_helpers.get_sdk_value_from_literal(_literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))) - assert o.to_python_std() is None - - o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), - sdk_type=_sdk_types.Types.Integer, - ) - assert o.to_python_std() is None - - o = _type_helpers.get_sdk_value_from_literal( - 
_literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))), - sdk_type=_sdk_types.Types.Integer, - ) - assert o.to_python_std() == 1 - - o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal( - collection=_literals.LiteralCollection( - [ - _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))), - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), - ] - ) - ) - ) - assert o.to_python_std() == [1, None] diff --git a/tests/flytekit/unit/common_tests/types/test_primitives.py b/tests/flytekit/unit/common_tests/types/test_primitives.py deleted file mode 100644 index b161c751dd..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_primitives.py +++ /dev/null @@ -1,289 +0,0 @@ -import datetime - -import pytest -from dateutil import tz - -from flytekit.common.exceptions import user as user_exceptions -from flytekit.common.types import base_sdk_types, primitives -from flytekit.models import types as literal_types - - -def test_integer(): - # Check type specification - assert primitives.Integer.to_flyte_literal_type().simple == literal_types.SimpleType.INTEGER - - # Test value behavior - obj = primitives.Integer.from_python_std(1) - assert obj.to_python_std() == 1 - assert primitives.Integer.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [ - 1.0, - "abc", - True, - False, - datetime.datetime.now(), - datetime.timedelta(seconds=1), - ]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Integer.from_python_std(val) - - obj = primitives.Integer.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Integer.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Integer.from_string("books") - obj = primitives.Integer.from_string("299792458") - assert obj.to_python_std() == 299792458 - assert primitives.Integer.from_flyte_idl(obj.to_flyte_idl()) == obj - - assert obj.short_string() == "Integer(299792458)" - assert obj.verbose_string() == "Integer(299792458)" - - -def test_float(): - # Check type specification - assert primitives.Float.to_flyte_literal_type().simple == literal_types.SimpleType.FLOAT - - # Test value behavior - obj = primitives.Float.from_python_std(1.0) - assert obj.to_python_std() == 1.0 - assert primitives.Float.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [ - 1, - "abc", - True, - False, - datetime.datetime.now(), - datetime.timedelta(seconds=1), - ]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Float.from_python_std(val) - - obj = primitives.Float.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Float.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Float.from_string("lightning") - obj = primitives.Float.from_string("2.71828") - assert obj.to_python_std() == 2.71828 - assert primitives.Float.from_flyte_idl(obj.to_flyte_idl()) == obj - - assert obj.short_string() == "Float(2.71828)" - assert obj.verbose_string() == "Float(2.71828)" - - -def test_boolean(): - # Check type specification - assert primitives.Boolean.to_flyte_literal_type().simple == literal_types.SimpleType.BOOLEAN - - # Test value behavior - obj = primitives.Boolean.from_python_std(True) - assert obj.to_python_std() is True - assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [1, 1.0, "abc", 
datetime.datetime.now(), datetime.timedelta(seconds=1)]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Boolean.from_python_std(val) - - obj = primitives.Boolean.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Boolean.from_string("lightning") - obj = primitives.Boolean.from_string("false") - assert not obj.to_python_std() - assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - obj = primitives.Boolean.from_string("False") - assert not obj.to_python_std() - obj = primitives.Boolean.from_string("0") - assert not obj.to_python_std() - obj = primitives.Boolean.from_string("true") - assert obj.to_python_std() - obj = primitives.Boolean.from_string("True") - assert obj.to_python_std() - obj = primitives.Boolean.from_string("1") - assert obj.to_python_std() - assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - - assert obj.short_string() == "Boolean(True)" - assert obj.verbose_string() == "Boolean(True)" - - -def test_string(): - # Check type specification - assert primitives.String.to_flyte_literal_type().simple == literal_types.SimpleType.STRING - - # Test value behavior - obj = primitives.String.from_python_std("abc") - assert obj.to_python_std() == "abc" - assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [ - 1, - 1.0, - True, - False, - datetime.datetime.now(), - datetime.timedelta(seconds=1), - ]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.String.from_python_std(val) - - obj = primitives.String.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - my_string = "this is a string" - obj = primitives.String.from_string(my_string) - assert obj.to_python_std() == my_string - assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj - - assert obj.short_string() == "String('this is a string')" - assert obj.verbose_string() == "String('this is a string')" - - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.String.from_string([]) - - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.String.from_string({}) - - -class UTC(datetime.tzinfo): - """UTC""" - - def utcoffset(self, dt): - return datetime.timedelta(0) - - def tzname(self, dt): - return "UTC" - - def dst(self, dt): - return datetime.timedelta(0) - - -def test_datetime(): - # Check type specification - assert primitives.Datetime.to_flyte_literal_type().simple == literal_types.SimpleType.DATETIME - - # Test value behavior - dt = datetime.datetime.now(tz=tz.UTC) - obj = primitives.Datetime.from_python_std(dt) - assert primitives.Datetime.from_flyte_idl(obj.to_flyte_idl()) == obj - assert obj.to_python_std() == dt - - # Timezone is required - with pytest.raises(user_exceptions.FlyteValueException): - primitives.Datetime.from_python_std(datetime.datetime.now()) - - for val in [1, 1.0, "abc", True, False, datetime.timedelta(seconds=1)]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Datetime.from_python_std(val) - - obj = primitives.Datetime.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Datetime.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteTypeException): - 
primitives.Datetime.from_string("not a real date") - obj = primitives.Datetime.from_string("2018-05-15 4:32pm UTC") - test_dt = datetime.datetime(2018, 5, 15, 16, 32, 0, 0, UTC()) - assert obj.short_string() == "Datetime(2018-05-15 16:32:00+00:00)" - assert obj.verbose_string() == "Datetime(2018-05-15 16:32:00+00:00)" - assert obj.to_python_std() == test_dt - assert primitives.Datetime.from_flyte_idl(obj.to_flyte_idl()) == obj - - -def test_timedelta(): - # Check type specification - assert primitives.Timedelta.to_flyte_literal_type().simple == literal_types.SimpleType.DURATION - - # Test value behavior - obj = primitives.Timedelta.from_python_std(datetime.timedelta(seconds=1)) - assert obj.to_python_std() == datetime.timedelta(seconds=1) - assert primitives.Timedelta.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [1.0, "abc", True, False, datetime.datetime.now()]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Timedelta.from_python_std(val) - - obj = primitives.Timedelta.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Timedelta.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Timedelta.from_string("not a real duration") - obj = primitives.Timedelta.from_string("15 hours, 1.1 second") - test_d = datetime.timedelta(hours=15, seconds=1, milliseconds=100) - assert obj.short_string() == "Timedelta(15:00:01.100000)" - assert obj.verbose_string() == "Timedelta(15:00:01.100000)" - assert obj.to_python_std() == test_d - assert primitives.Timedelta.from_flyte_idl(obj.to_flyte_idl()) == obj - - -def test_void(): - # Check type specification - with pytest.raises(user_exceptions.FlyteAssertion): - base_sdk_types.Void.to_flyte_literal_type() - - # Test value behavior - for val in [ - 1, - 1.0, - "abc", - True, - False, - datetime.datetime.now(), - datetime.timedelta(seconds=1), - None, - ]: - assert base_sdk_types.Void.from_python_std(val).to_python_std() is None - - obj = base_sdk_types.Void() - assert base_sdk_types.Void.from_flyte_idl(obj.to_flyte_idl()) == obj - - assert obj.short_string() == "Void()" - assert obj.verbose_string() == "Void()" - - -def test_generic(): - # Check type specification - assert primitives.Generic.to_flyte_literal_type().simple == literal_types.SimpleType.STRUCT - - # Test value behavior - d = {"a": [1, 2, 3], "b": "abc", "c": 1, "d": {"a": 1}} - obj = primitives.Generic.from_python_std(d) - assert obj.to_python_std() == d - assert primitives.Generic.from_flyte_idl(obj.to_flyte_idl()) == obj - - for val in [ - 1.0, - "abc", - True, - False, - datetime.datetime.now(), - datetime.timedelta(seconds=1), - ]: - with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Generic.from_python_std(val) - - obj = primitives.Generic.from_python_std(None) - assert obj.to_python_std() is None - assert primitives.Generic.from_flyte_idl(obj.to_flyte_idl()) == obj - - # Test string parsing - with pytest.raises(user_exceptions.FlyteValueException): - primitives.Generic.from_string("1") - obj = primitives.Generic.from_string('{"a": 1.0}') - assert obj.to_python_std() == {"a": 1.0} - assert primitives.Generic.from_flyte_idl(obj.to_flyte_idl()) == obj diff --git a/tests/flytekit/unit/common_tests/types/test_proto.py b/tests/flytekit/unit/common_tests/types/test_proto.py deleted file mode 100644 index 074fb1dfd8..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_proto.py +++ /dev/null @@ -1,63 +0,0 @@ -import base64 as 
_base64 - -import pytest as _pytest -from flyteidl.core import errors_pb2 as _errors_pb2 - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import proto as _proto -from flytekit.common.types.proto import ProtobufType -from flytekit.models import types as _type_models - - -def test_wrong_type(): - with _pytest.raises(_user_exceptions.FlyteTypeException): - _proto.create_protobuf(int) - - -def test_proto_to_literal_type(): - proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) - assert proto_type.to_flyte_literal_type().simple == _type_models.SimpleType.BINARY - assert len(proto_type.to_flyte_literal_type().metadata) == 1 - assert ( - proto_type.to_flyte_literal_type().metadata[_proto.Protobuf.PB_FIELD_KEY] - == "flyteidl.core.errors_pb2.ContainerError" - ) - - -def test_proto(): - proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) - assert proto_type.short_class_string() == "Types.Proto(flyteidl.core.errors_pb2.ContainerError)" - run_test_proto_type(proto_type) - - -def test_generic_proto(): - proto_type = _proto.create_generic(_errors_pb2.ContainerError) - assert proto_type.short_class_string() == "Types.GenericProto(flyteidl.core.errors_pb2.ContainerError)" - run_test_proto_type(proto_type) - - -def run_test_proto_type(proto_type: ProtobufType): - pb = _errors_pb2.ContainerError(code="code", message="message") - obj = proto_type.from_python_std(pb) - obj2 = proto_type.from_flyte_idl(obj.to_flyte_idl()) - assert obj == obj2 - - obj = obj.to_python_std() - obj2 = obj2.to_python_std() - - assert obj.code == "code" - assert obj.message == "message" - - assert obj2.code == "code" - assert obj2.message == "message" - - -def test_from_string(): - proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) - - pb = _errors_pb2.ContainerError(code="code", message="message") - pb_str = _base64.b64encode(pb.SerializeToString()) - - obj = proto_type.from_string(pb_str) - assert obj.to_python_std().code == "code" - assert obj.to_python_std().message == "message" diff --git a/tests/flytekit/unit/common_tests/types/test_schema.py b/tests/flytekit/unit/common_tests/types/test_schema.py deleted file mode 100644 index 02bfb8f55e..0000000000 --- a/tests/flytekit/unit/common_tests/types/test_schema.py +++ /dev/null @@ -1,69 +0,0 @@ -from flytekit.common.types import primitives, schema -from flytekit.common.types.impl import schema as schema_impl -from flytekit.sdk import test_utils - -_ALL_COLUMN_TYPES = [ - ("a", primitives.Integer), - ("b", primitives.String), - ("c", primitives.Float), - ("d", primitives.Datetime), - ("e", primitives.Timedelta), - ("f", primitives.Boolean), -] - - -def test_generic_schema_instantiator(): - instantiator = schema.schema_instantiator() - b = instantiator.create_at_known_location("abc") - assert isinstance(b, schema_impl.Schema) - assert b.remote_location == "abc/" - assert b.mode == "wb" - assert len(b.type.columns) == 0 - - -def test_typed_schema_instantiator(): - instantiator = schema.schema_instantiator(_ALL_COLUMN_TYPES) - b = instantiator.create_at_known_location("abc") - assert isinstance(b, schema_impl.Schema) - assert b.remote_location == "abc/" - assert b.mode == "wb" - assert len(b.type.columns) == len(_ALL_COLUMN_TYPES) - assert list(b.type.sdk_columns.items()) == _ALL_COLUMN_TYPES - - -def test_generic_schema(): - with test_utils.LocalTestFileSystem() as t: - instantiator = schema.schema_instantiator() - b = instantiator() - assert isinstance(b, schema_impl.Schema) - assert b.mode == "wb" - 
assert len(b.type.columns) == 0 - assert b.remote_location.startswith(t.name) - - -def test_typed_schema(): - with test_utils.LocalTestFileSystem() as t: - instantiator = schema.schema_instantiator(_ALL_COLUMN_TYPES) - b = instantiator() - assert isinstance(b, schema_impl.Schema) - assert b.mode == "wb" - assert len(b.type.columns) == len(_ALL_COLUMN_TYPES) - assert list(b.type.sdk_columns.items()) == _ALL_COLUMN_TYPES - assert b.remote_location.startswith(t.name) - - -# Ensures that subclassing types works inside a schema. -def test_casting(): - class MyDateTime(primitives.Datetime): - ... - - with test_utils.LocalTestFileSystem(): - test_columns_1 = [("altered", MyDateTime)] - test_columns_2 = [("altered", primitives.Datetime)] - - instantiator_1 = schema.schema_instantiator(test_columns_1) - a = instantiator_1() - - instantiator_2 = schema.schema_instantiator(test_columns_2) - - a.cast_to(instantiator_2._schema_type) diff --git a/tests/flytekit/unit/configuration/test_waterfall.py b/tests/flytekit/unit/configuration/test_waterfall.py index 2b35d641c0..a4cd749a80 100644 --- a/tests/flytekit/unit/configuration/test_waterfall.py +++ b/tests/flytekit/unit/configuration/test_waterfall.py @@ -1,7 +1,7 @@ import os as _os -from flytekit.common.utils import AutoDeletingTempDir as _AutoDeletingTempDir from flytekit.configuration import common as _common +from flytekit.core.utils import AutoDeletingTempDir as _AutoDeletingTempDir def test_lookup_waterfall_raw_env_var(): diff --git a/tests/flytekit/unit/contrib/__init__.py b/tests/flytekit/unit/contrib/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/contrib/sensors/__init__.py b/tests/flytekit/unit/contrib/sensors/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/contrib/sensors/test_impl.py b/tests/flytekit/unit/contrib/sensors/test_impl.py deleted file mode 100644 index 8382dae762..0000000000 --- a/tests/flytekit/unit/contrib/sensors/test_impl.py +++ /dev/null @@ -1,58 +0,0 @@ -import mock -from hmsclient import HMSClient -from hmsclient.genthrift.hive_metastore import ttypes as _ttypes - -from flytekit.contrib.sensors.impl import HiveFilteredPartitionSensor, HiveNamedPartitionSensor, HiveTableSensor - - -def test_HiveTableSensor(): - hive_table_sensor = HiveTableSensor(table_name="mocked_table", host="localhost", port=1234) - assert hive_table_sensor._schema == "default" - with mock.patch.object(HMSClient, "open"): - with mock.patch.object(HMSClient, "get_table"): - success, interval = hive_table_sensor._do_poll() - assert success - assert interval is None - - with mock.patch.object(HMSClient, "get_table", side_effect=_ttypes.NoSuchObjectException()): - success, interval = hive_table_sensor._do_poll() - assert not success - assert interval is None - - -def test_HiveNamedPartitionSensor(): - hive_named_partition_sensor = HiveNamedPartitionSensor( - table_name="mocked_table", partition_names=["ds=2019-10-10", "ds=2019-10-11"], host="localhost", port=1234 - ) - assert hive_named_partition_sensor._schema == "default" - with mock.patch.object(HMSClient, "open"): - with mock.patch.object(HMSClient, "get_partition_by_name"): - success, interval = hive_named_partition_sensor._do_poll() - assert success - assert interval is None - - with mock.patch.object( - HMSClient, - "get_partition_by_name", - side_effect=_ttypes.NoSuchObjectException(), - ): - success, interval = hive_named_partition_sensor._do_poll() - assert not success - assert interval is None - 
- -def test_HiveFilteredPartitionSensor(): - hive_filtered_partition_sensor = HiveFilteredPartitionSensor( - table_name="mocked_table", partition_filter="ds = '2019-10-10' AND region = 'NYC'", host="localhost", port=1234 - ) - assert hive_filtered_partition_sensor._schema == "default" - with mock.patch.object(HMSClient, "open"): - with mock.patch.object(HMSClient, "get_partitions_by_filter", return_value=["any"]): - success, interval = hive_filtered_partition_sensor._do_poll() - assert success - assert interval is None - - with mock.patch.object(HMSClient, "get_partitions_by_filter", return_value=[]): - success, interval = hive_filtered_partition_sensor._do_poll() - assert not success - assert interval is None diff --git a/tests/flytekit/unit/contrib/sensors/test_task.py b/tests/flytekit/unit/contrib/sensors/test_task.py deleted file mode 100644 index 0956b2aada..0000000000 --- a/tests/flytekit/unit/contrib/sensors/test_task.py +++ /dev/null @@ -1,22 +0,0 @@ -from flytekit.contrib.sensors.base_sensor import Sensor as _Sensor -from flytekit.contrib.sensors.task import sensor_task - - -class MyMockSensor(_Sensor): - def __init__(self, **kwargs): - super(MyMockSensor, self).__init__(**kwargs) - - def _do_poll(self): - """ - :rtype: (bool, Optional[datetime.timedelta]) - """ - return True, None - - -def test_sensor_works(): - @sensor_task - def my_test_task(wf_params): - return MyMockSensor() - - out = my_test_task.unit_test() - assert len(out) == 0 diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index b490a4f9a0..13cfcf3706 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -5,11 +5,11 @@ import pytest from flytekit import task, workflow -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings from flytekit.models.core.workflow import Node +from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index db7e8c3d10..93a4d0e039 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -7,7 +7,6 @@ import pytest -from flytekit.common.exceptions.user import FlyteAssertion from flytekit.core import context_manager from flytekit.core.context_manager import ExecutionState, FlyteContextManager, Image, ImageConfig from flytekit.core.data_persistence import FileAccessProvider @@ -15,6 +14,7 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 874eea27a5..c6134558b4 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -1,13 +1,13 @@ from collections import OrderedDict from typing import Dict, List -from flytekit.common.translator import get_serializable from 
flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.models.core.types import BlobType from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index 6b99c93368..f120398306 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -4,16 +4,16 @@ import pandas as pd import pytest -from flytekit.common.exceptions.user import FlyteValidationException -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import kwtypes from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import reference_task, task from flytekit.core.workflow import ImperativeWorkflow, get_promise, workflow +from flytekit.exceptions.user import FlyteValidationException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.models import literals as literal_models +from flytekit.tools.translator import get_serializable from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDatasetType diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index 72d10fb2d1..baa33cd356 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -4,7 +4,6 @@ import pytest from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl -from flytekit.common.translator import get_serializable from flytekit.core import context_manager, launch_plan, notification from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.schedule import CronSchedule @@ -13,6 +12,7 @@ from flytekit.models.common import Annotations, AuthRole, Labels, RawOutputDataConfig from flytekit.models.core import execution as _execution_model from flytekit.models.core import identifier as identifier_models +from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 31cbceffe9..31253ce1dd 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -4,12 +4,12 @@ import pytest from flytekit import LaunchPlan, map_task -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.map_task import MapPythonTask from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow +from flytekit.tools.translator import get_serializable @task diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 4b1c59c479..811b2d46e5 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -5,16 +5,16 @@ import 
pytest from flytekit import Resources, map_task -from flytekit.common.exceptions.user import FlyteAssertion -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models +from flytekit.tools.translator import get_serializable def test_normal_task(): diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index f4fa715562..5151539236 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -3,7 +3,6 @@ import pytest -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import kwtypes from flytekit.core.context_manager import Image, ImageConfig @@ -16,6 +15,7 @@ from flytekit.core.testing import patch, task_mock from flytekit.core.workflow import reference_workflow, workflow from flytekit.models.core import identifier as _identifier_model +from flytekit.tools.translator import get_serializable # This is used for docs diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index 4fa33bbe34..ea44099587 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -3,7 +3,6 @@ import pytest -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver @@ -11,6 +10,7 @@ from flytekit.core.python_auto_container import default_task_resolver from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index fac26994a1..d3395e9fd5 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -5,7 +5,6 @@ import pytest from flytekit import ContainerTask, kwtypes -from flytekit.common.translator import get_serializable from flytekit.configuration import set_flyte_config_file from flytekit.core import context_manager from flytekit.core.condition import conditional @@ -13,6 +12,7 @@ from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.models.types import SimpleType +from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 2de16e6c9b..64f4b6f379 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,6 @@ from pandas._testing import assert_frame_equal from flytekit import kwtypes -from flytekit.common.exceptions import user as user_exceptions from 
flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import ( DataclassTransformer, @@ -30,6 +29,7 @@ convert_json_schema_to_python_class, dataclass_from_dict, ) +from flytekit.exceptions import user as user_exceptions from flytekit.models import types as model_types from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 54c1c77b8c..4b92698022 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -17,7 +17,6 @@ import flytekit from flytekit import ContainerTask, Secret, SQLTask, dynamic, kwtypes, map_task -from flytekit.common.translator import get_serializable from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig @@ -34,6 +33,7 @@ from flytekit.models.interface import Parameter from flytekit.models.task import Resources as _resource_models from flytekit.models.types import LiteralType, SimpleType +from flytekit.tools.translator import get_serializable from flytekit.types.directory import FlyteDirectory, TensorboardLogs from flytekit.types.file import FlyteFile, PNGImageFile from flytekit.types.schema import FlyteSchema, SchemaOpenMode diff --git a/tests/flytekit/unit/common_tests/test_utils.py b/tests/flytekit/unit/core/test_utils.py similarity index 91% rename from tests/flytekit/unit/common_tests/test_utils.py rename to tests/flytekit/unit/core/test_utils.py index a24d585d1c..112a864b30 100644 --- a/tests/flytekit/unit/common_tests/test_utils.py +++ b/tests/flytekit/unit/core/test_utils.py @@ -1,6 +1,6 @@ import pytest -from flytekit.common.utils import _dnsify +from flytekit.core.utils import _dnsify @pytest.mark.parametrize( diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 5054adc3b2..ef435393ef 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -6,13 +6,13 @@ from pandas.testing import assert_frame_equal from flytekit import StructuredDataset, kwtypes -from flytekit.common.exceptions.user import FlyteValidationException, FlyteValueException -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow +from flytekit.exceptions.user import FlyteValidationException, FlyteValueException +from flytekit.tools.translator import get_serializable from flytekit.types.schema import FlyteSchema try: diff --git a/tests/flytekit/unit/engines/__init__.py b/tests/flytekit/unit/engines/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/engines/flyte/__init__.py b/tests/flytekit/unit/engines/flyte/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py deleted file mode 100644 index ba7d478cd0..0000000000 --- 
a/tests/flytekit/unit/engines/flyte/test_engine.py +++ /dev/null @@ -1,807 +0,0 @@ -import os - -import pytest -from flyteidl.core import errors_pb2 -from mock import MagicMock, PropertyMock, patch - -from flytekit.common import constants, utils -from flytekit.common.exceptions import scopes -from flytekit.configuration import TemporaryConfiguration -from flytekit.engines.flyte import engine -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals -from flytekit.models import task as _task_models -from flytekit.models.admin import common as _common -from flytekit.models.core import errors, identifier -from flytekit.sdk import test_utils - -_INPUT_MAP = literals.LiteralMap( - {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))} -) -_OUTPUT_MAP = literals.LiteralMap( - {"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))} -) -_EMPTY_LITERAL_MAP = literals.LiteralMap(literals={}) - - -@pytest.fixture(scope="function", autouse=True) -def temp_config(): - with TemporaryConfiguration( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "../../../common/configs/local.config", - ), - internal_overrides={ - "image": "myflyteimage:{}".format(os.environ.get("IMAGE_VERSION", "sha")), - "project": "myflyteproject", - "domain": "development", - }, - ): - yield - - -@pytest.fixture(scope="function", autouse=True) -def execution_data_locations(): - with test_utils.LocalTestFileSystem() as fs: - input_filename = fs.get_named_tempfile("inputs.pb") - output_filename = fs.get_named_tempfile("outputs.pb") - utils.write_proto_to_file(_INPUT_MAP.to_flyte_idl(), input_filename) - utils.write_proto_to_file(_OUTPUT_MAP.to_flyte_idl(), output_filename) - yield ( - _common_models.UrlBlob(input_filename, 100), - _common_models.UrlBlob(output_filename, 100), - ) - - -@scopes.system_entry_point -def _raise_system_exception(*args, **kwargs): - raise ValueError("errorERRORerror") - - -@scopes.user_entry_point -def _raise_user_exception(*args, **kwargs): - raise ValueError("userUSERuser") - - -@scopes.system_entry_point -def test_task_system_failure(): - m = MagicMock() - m.execute = _raise_system_exception - - with utils.AutoDeletingTempDir("test") as tmp: - engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) - - doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file( - errors_pb2.ErrorDocument, - os.path.join(tmp.name, constants.ERROR_FILE_NAME), - ) - ) - assert doc.error.code == "SYSTEM:Unknown" - assert doc.error.kind == errors.ContainerError.Kind.RECOVERABLE - assert "errorERRORerror" in doc.error.message - - -@scopes.system_entry_point -def test_task_user_failure(): - m = MagicMock() - m.execute = _raise_user_exception - - with utils.AutoDeletingTempDir("test") as tmp: - engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) - - doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file( - errors_pb2.ErrorDocument, - os.path.join(tmp.name, constants.ERROR_FILE_NAME), - ) - ) - assert doc.error.code == "USER:Unknown" - assert doc.error.kind == errors.ContainerError.Kind.NON_RECOVERABLE - assert "userUSERuser" in doc.error.message - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_execution_notification_overrides(mock_client_factory): - mock_client = MagicMock() - 
mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") - ) - - engine.FlyteLaunchPlan(m).launch("xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[]) - - mock_client.create_execution.assert_called_once_with( - "xp", - "xd", - "xn", - _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version", - ), - _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), - disable_all=True, - ), - literals.LiteralMap({}), - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_execution_notification_soft_overrides(mock_client_factory): - mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") - ) - - notification = _common_models.Notification([0, 1, 2], email=_common_models.EmailNotification(["me@place.com"])) - - engine.FlyteLaunchPlan(m).launch("xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[notification]) - - mock_client.create_execution.assert_called_once_with( - "xp", - "xd", - "xn", - _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version", - ), - _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), - notifications=_execution_models.NotificationList([notification]), - ), - literals.LiteralMap({}), - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_execution_label_overrides(mock_client_factory): - mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") - ) - - labels = _common_models.Labels({"my": "label"}) - engine.FlyteLaunchPlan(m).execute( - "xp", - "xd", - "xn", - literals.LiteralMap({}), - notification_overrides=[], - label_overrides=labels, - ) - - mock_client.create_execution.assert_called_once_with( - "xp", - "xd", - "xn", - _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version", - ), - _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), - disable_all=True, - labels=labels, - ), - literals.LiteralMap({}), - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_execution_annotation_overrides(mock_client_factory): - mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - 
return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") - ) - - annotations = _common_models.Annotations({"my": "annotation"}) - engine.FlyteLaunchPlan(m).launch( - "xp", - "xd", - "xn", - literals.LiteralMap({}), - notification_overrides=[], - annotation_overrides=annotations, - ) - - mock_client.create_execution.assert_called_once_with( - "xp", - "xd", - "xn", - _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version", - ), - _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), - disable_all=True, - annotations=annotations, - ), - literals.LiteralMap({}), - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_fetch_launch_plan(mock_client_factory): - mock_client = MagicMock() - mock_client.get_launch_plan = MagicMock( - return_value=_launch_plan_models.LaunchPlan( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1"), - MagicMock(), - MagicMock(), - ) - ) - mock_client_factory.return_value = mock_client - - lp = engine.FlyteEngineFactory().fetch_launch_plan( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p", "d", "n", "v") - ) - assert lp.id == identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1") - - mock_client.get_launch_plan.assert_called_once_with( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p", "d", "n", "v") - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_fetch_active_launch_plan(mock_client_factory): - mock_client = MagicMock() - mock_client.get_active_launch_plan = MagicMock( - return_value=_launch_plan_models.LaunchPlan( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1"), - MagicMock(), - MagicMock(), - ) - ) - mock_client_factory.return_value = mock_client - - lp = engine.FlyteEngineFactory().fetch_launch_plan( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p", "d", "n", "") - ) - assert lp.id == identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1") - - mock_client.get_active_launch_plan.assert_called_once_with(_common_models.NamedEntityIdentifier("p", "d", "n")) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_execution_inputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse( - None, - None, - _INPUT_MAP, - _OUTPUT_MAP, - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) - - inputs = engine.FlyteWorkflowExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - mock_client.get_execution_data.assert_called_once_with( - identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_execution_inputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], 
_EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) - - inputs = engine.FlyteWorkflowExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - mock_client.get_execution_data.assert_called_once_with( - identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_execution_outputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) - - outputs = engine.FlyteWorkflowExecution(m).get_outputs() - assert len(outputs.literals) == 1 - assert outputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_execution_data.assert_called_once_with( - identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_execution_outputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) - - inputs = engine.FlyteWorkflowExecution(m).get_outputs() - assert len(inputs.literals) == 1 - assert inputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_execution_data.assert_called_once_with( - identifier.WorkflowExecutionIdentifier("project", "domain", "name") - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_node_execution_inputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse( - None, - None, - _INPUT_MAP, - _OUTPUT_MAP, - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - inputs = engine.FlyteNodeExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - mock_client.get_node_execution_data.assert_called_once_with( - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_node_execution_inputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], 
_EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - inputs = engine.FlyteNodeExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - mock_client.get_node_execution_data.assert_called_once_with( - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_node_execution_outputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - outputs = engine.FlyteNodeExecution(m).get_outputs() - assert len(outputs.literals) == 1 - assert outputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_node_execution_data.assert_called_once_with( - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_node_execution_outputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - inputs = engine.FlyteNodeExecution(m).get_outputs() - assert len(inputs.literals) == 1 - assert inputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_node_execution_data.assert_called_once_with( - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_task_execution_inputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - inputs = engine.FlyteTaskExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - 
mock_client.get_task_execution_data.assert_called_once_with( - identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_task_execution_inputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - inputs = engine.FlyteTaskExecution(m).get_inputs() - assert len(inputs.literals) == 1 - assert inputs.literals["a"].scalar.primitive.integer == 1 - mock_client.get_task_execution_data.assert_called_once_with( - identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_full_task_execution_outputs(mock_client_factory): - mock_client = MagicMock() - mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - outputs = engine.FlyteTaskExecution(m).get_outputs() - assert len(outputs.literals) == 1 - assert outputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_task_execution_data.assert_called_once_with( - identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_get_task_execution_outputs(mock_client_factory, execution_data_locations): - mock_client = MagicMock() - mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP - ) - ) - mock_client_factory.return_value = mock_client - - m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.TaskExecutionIdentifier( - 
identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - inputs = engine.FlyteTaskExecution(m).get_outputs() - assert len(inputs.literals) == 1 - assert inputs.literals["b"].scalar.primitive.integer == 2 - mock_client.get_task_execution_data.assert_called_once_with( - identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - "project", - "domain", - "task-name", - "version", - ), - identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ), - ), - 0, - ) - ) - - -@pytest.mark.parametrize( - "tasks", - [ - [ - _task_models.Task( - identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), - MagicMock(), - ) - ], - [], - ], -) -@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) -def test_fetch_latest_task(mock_client_factory, tasks): - mock_client = MagicMock() - mock_client.list_tasks_paginated = MagicMock(return_value=(tasks, 0)) - mock_client_factory.return_value = mock_client - - task = engine.FlyteEngineFactory().fetch_latest_task(_common_models.NamedEntityIdentifier("p", "d", "n")) - - if tasks: - assert task.id == tasks[0].id - else: - assert not task - - mock_client.list_tasks_paginated.assert_called_once_with( - _common_models.NamedEntityIdentifier("p", "d", "n"), - limit=1, - sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), - ) diff --git a/tests/flytekit/unit/engines/test_loader.py b/tests/flytekit/unit/engines/test_loader.py deleted file mode 100644 index 5f6ee2d46c..0000000000 --- a/tests/flytekit/unit/engines/test_loader.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from flytekit.engines import loader -from flytekit.engines.unit import engine as _unit_engine - - -def test_unit_load(): - assert isinstance(loader.get_engine("unit"), _unit_engine.UnitTestEngineFactory) - - -def test_bad_load(): - with pytest.raises(Exception): - loader.get_engine("badname") diff --git a/tests/flytekit/unit/engines/unit/__init__.py b/tests/flytekit/unit/engines/unit/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/common/core/__init__.py b/tests/flytekit/unit/exceptions/__init__.py similarity index 100% rename from flytekit/common/core/__init__.py rename to tests/flytekit/unit/exceptions/__init__.py diff --git a/tests/flytekit/unit/common_tests/exceptions/test_base.py b/tests/flytekit/unit/exceptions/test_base.py similarity index 85% rename from tests/flytekit/unit/common_tests/exceptions/test_base.py rename to tests/flytekit/unit/exceptions/test_base.py index f4ede26b74..76b6465d8b 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_base.py +++ b/tests/flytekit/unit/exceptions/test_base.py @@ -1,4 +1,4 @@ -from flytekit.common.exceptions import base +from flytekit.exceptions import base def test_flyte_exception(): diff --git a/tests/flytekit/unit/common_tests/exceptions/test_scopes.py b/tests/flytekit/unit/exceptions/test_scopes.py similarity index 98% rename from tests/flytekit/unit/common_tests/exceptions/test_scopes.py rename to tests/flytekit/unit/exceptions/test_scopes.py index f14ced33f9..75ef74383e 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_scopes.py +++ b/tests/flytekit/unit/exceptions/test_scopes.py @@ -1,6 +1,6 @@ import pytest -from 
flytekit.common.exceptions import scopes, system, user +from flytekit.exceptions import scopes, system, user from flytekit.models.core import errors as _error_models diff --git a/tests/flytekit/unit/common_tests/exceptions/test_system.py b/tests/flytekit/unit/exceptions/test_system.py similarity index 98% rename from tests/flytekit/unit/common_tests/exceptions/test_system.py rename to tests/flytekit/unit/exceptions/test_system.py index d53ed00f6c..2543af320a 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_system.py +++ b/tests/flytekit/unit/exceptions/test_system.py @@ -1,4 +1,4 @@ -from flytekit.common.exceptions import base, system +from flytekit.exceptions import base, system def test_flyte_system_exception(): diff --git a/tests/flytekit/unit/common_tests/exceptions/test_user.py b/tests/flytekit/unit/exceptions/test_user.py similarity index 98% rename from tests/flytekit/unit/common_tests/exceptions/test_user.py rename to tests/flytekit/unit/exceptions/test_user.py index e3b3fbd319..78dc723def 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_user.py +++ b/tests/flytekit/unit/exceptions/test_user.py @@ -1,4 +1,4 @@ -from flytekit.common.exceptions import base, user +from flytekit.exceptions import base, user def test_flyte_user_exception(): diff --git a/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py b/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py index 0737cd5b92..b52978dd58 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py +++ b/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py @@ -1,8 +1,8 @@ from collections import OrderedDict -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig +from flytekit.tools.translator import get_serializable from tests.flytekit.unit.extras.sqlite3.test_task import tk as not_tk diff --git a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py b/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py deleted file mode 100644 index ea799ccf4e..0000000000 --- a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py +++ /dev/null @@ -1,80 +0,0 @@ -import os as _os - -import mock as _mock -import pytest as _pytest - -from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy - - -@_pytest.fixture -def mock_update_cmd_config_and_execute(): - p = _mock.patch("flytekit.interfaces.data.gcs.gcs_proxy._update_cmd_config_and_execute") - yield p.start() - p.stop() - - -@_pytest.fixture -def gsutil_parallelism(): - p = _mock.patch("flytekit.configuration.gcp.GSUTIL_PARALLELISM.get", return_value=True) - yield p.start() - p.stop() - - -@_pytest.fixture -def gcs_proxy(): - return _gcs_proxy.GCSProxy() - - -def test_upload_directory(mock_update_cmd_config_and_execute, gcs_proxy): - local_path, remote_path = "/foo/*", "gs://bar/0/" - gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "cp", "-r", local_path, remote_path]) - - -def test_upload_directory_padding_wildcard_for_local_path(mock_update_cmd_config_and_execute, gcs_proxy): - local_path, remote_path = "/foo", "gs://bar/0/" - gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with( - ["gsutil", "cp", "-r", _os.path.join(local_path, "*"), remote_path] - ) - - -def test_upload_directory_padding_slash_for_remote_path(mock_update_cmd_config_and_execute, gcs_proxy): - local_path, remote_path = "/foo/*", 
"gs://bar/0" - gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "cp", "-r", local_path, remote_path + "/"]) - - -def test_maybe_with_gsutil_parallelism_disabled(gcs_proxy): - local_path, remote_path = "foo", "gs://bar/0/" - cmd = gcs_proxy._maybe_with_gsutil_parallelism("cp", local_path, remote_path) - assert cmd == ["gsutil", "cp", local_path, remote_path] - - -def test_maybe_with_gsutil_parallelism_enabled(gsutil_parallelism, gcs_proxy): - local_path, remote_path = "foo", "gs://bar/0/" - cmd = gcs_proxy._maybe_with_gsutil_parallelism("cp", "-r", local_path, remote_path) - assert cmd == ["gsutil", "-m", "cp", "-r", local_path, remote_path] - - -def test_download_with_parallelism(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): - local_path, remote_path = "/foo", "gs://bar/0/" - gcs_proxy.download(remote_path, local_path) - mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "-m", "cp", remote_path, local_path]) - - -def test_upload_directory_with_parallelism(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): - local_path, remote_path = "/foo/*", "gs://bar/0/" - gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "-m", "cp", "-r", local_path, remote_path]) - - -def test_raw_prefix_property(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): - gcs_with_raw_prefix = _gcs_proxy.GCSProxy("gcs://stuff") - assert gcs_with_raw_prefix.raw_output_data_prefix_override == "gcs://stuff" - - -def test_random_path(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): - gcs_with_raw_prefix = _gcs_proxy.GCSProxy("gcs://stuff") - result = gcs_with_raw_prefix.get_random_path() - assert result.startswith("gcs://stuff") diff --git a/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py b/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py deleted file mode 100644 index 7493cd3c81..0000000000 --- a/tests/flytekit/unit/interfaces/data/s3/test_s3_proxy.py +++ /dev/null @@ -1,36 +0,0 @@ -import mock as _mock - -from flytekit.interfaces.data.s3.s3proxy import AwsS3Proxy as _AwsS3Proxy - - -def test_property(): - aws = _AwsS3Proxy("s3://raw-output") - assert aws.raw_output_data_prefix_override == "s3://raw-output" - - -@_mock.patch("flytekit.configuration.aws.S3_SHARD_FORMATTER") -def test_random_path(mock_formatter): - mock_formatter.get.return_value = "s3://flyte/{}/" - - # Without raw output data prefix override - aws = _AwsS3Proxy() - p = str(aws.get_random_path()) - assert p.startswith("s3://flyte") - - # With override - aws = _AwsS3Proxy("s3://raw-output") - p = str(aws.get_random_path()) - assert p.startswith("s3://raw-output") - - -@_mock.patch("flytekit.interfaces.data.s3.s3proxy.AwsS3Proxy._check_binary") -@_mock.patch("flytekit.configuration.aws.BACKOFF_SECONDS") -@_mock.patch("flytekit.interfaces.data.s3.s3proxy._subprocess") -def test_retries(mock_subprocess, mock_delay, mock_check): - mock_delay.get.return_value = 0 - mock_subprocess.check_call.side_effect = Exception("test exception (404)") - mock_check.return_value = True - - proxy = _AwsS3Proxy() - assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 4 diff --git a/tests/flytekit/unit/models/test_dynamic_spark.py b/tests/flytekit/unit/models/test_dynamic_spark.py deleted file mode 100644 index c875ef9f71..0000000000 --- 
a/tests/flytekit/unit/models/test_dynamic_spark.py +++ /dev/null @@ -1,29 +0,0 @@ -from flytekit.common import constants as _sdk_constants -from flytekit.configuration import TemporaryConfiguration -from flytekit.sdk import tasks as _tasks -from flytekit.sdk.types import Types as _Types - - -@_tasks.outputs(o=_Types.Integer) -@_tasks.inputs(i=_Types.Integer) -@_tasks.spark_task(spark_conf={"x": "y"}) -def my_spark_task(ctx, sc, i, o): - pass - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.outputs(out=_Types.Integer) -@_tasks.dynamic_task -def spark_yield_task(wf_params, num, out): - wf_params.logging.info("Running inner task... yielding a launchplan") - t = my_spark_task.with_overrides(new_spark_conf={"a": "b"}) - o = t(i=num) - yield o - out.set(o.outputs.o) - - -def test_spark_yield(): - with TemporaryConfiguration(None, internal_overrides={"image": "fakeimage"}): - outputs = spark_yield_task.unit_test(num=1) - dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME] - print(dj_spec) diff --git a/tests/flytekit/unit/models/test_dynamic_wfs.py b/tests/flytekit/unit/models/test_dynamic_wfs.py deleted file mode 100644 index 302e87e705..0000000000 --- a/tests/flytekit/unit/models/test_dynamic_wfs.py +++ /dev/null @@ -1,125 +0,0 @@ -from flytekit.common import constants as _sdk_constants -from flytekit.sdk import tasks as _tasks -from flytekit.sdk import workflow as _workflow -from flytekit.sdk.types import Types as _Types - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.outputs(out=_Types.Integer) -@_tasks.python_task -def inner_task(wf_params, num, out): - wf_params.logging.info("Running inner task... setting output to input") - out.set(num) - - -@_workflow.workflow_class() -class IdentityWorkflow(object): - a = _workflow.Input(_Types.Integer, default=5, help="Input for inner workflow") - odd_nums_task = inner_task(num=a) - task_output = _workflow.Output(odd_nums_task.outputs.out, sdk_type=_Types.Integer) - - -id_lp = IdentityWorkflow.create_launch_plan() - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.outputs(out=_Types.Integer) -@_tasks.dynamic_task -def lp_yield_task(wf_params, num, out): - wf_params.logging.info("Running inner task... yielding a launchplan") - identity_lp_execution = id_lp(a=num) - yield identity_lp_execution - out.set(identity_lp_execution.outputs.task_output) - - -def test_dynamic_launch_plan_yielding(): - outputs = lp_yield_task.unit_test(num=10) - # TODO: Currently, Flytekit will not return early and not do anything if there are any workflow nodes detected - # in the output of a dynamic task. - dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME] - - assert dj_spec.min_successes == 1 - - launch_plan_node = dj_spec.nodes[0] - node_id = launch_plan_node.id - assert "models-test-dynamic-wfs-id-lp" in node_id - assert node_id.endswith("-0") - - # Assert that the output of the dynamic job spec is bound to the single node in the spec, the workflow node - # containing the launch plan - assert dj_spec.outputs[0].var == "out" - assert dj_spec.outputs[0].binding.promise.node_id == node_id - assert dj_spec.outputs[0].binding.promise.var == "task_output" - - -@_tasks.python_task -def empty_task(wf_params): - wf_params.logging.info("Running empty task") - - -@_workflow.workflow_class() -class EmptyWorkflow(object): - empty_task_task_execution = empty_task() - - -constant_workflow_lp = EmptyWorkflow.create_launch_plan() - - -@_tasks.outputs(out=_Types.Integer) -@_tasks.dynamic_task -def lp_yield_empty_wf(wf_params, out): - wf_params.logging.info("Running inner task... 
yielding a launchplan for empty workflow") - constant_lp_yielding_task_execution = constant_workflow_lp() - yield constant_lp_yielding_task_execution - out.set(42) - - -def test_dynamic_launch_plan_yielding_of_constant_workflow(): - outputs = lp_yield_empty_wf.unit_test() - # TODO: Currently, Flytekit will not return early and not do anything if there are any workflow nodes detected - # in the output of a dynamic task. - dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME] - - assert len(dj_spec.nodes) == 1 - assert len(dj_spec.outputs) == 1 - assert dj_spec.outputs[0].var == "out" - assert len(outputs.keys()) == 2 - - -@_tasks.inputs(num=_Types.Integer) -@_tasks.python_task -def log_only_task(wf_params, num): - wf_params.logging.info("{} was called".format(num)) - - -@_workflow.workflow_class() -class InputOnlyWorkflow(object): - a = _workflow.Input(_Types.Integer, default=5, help="Input for inner workflow") - log_only_task_execution = log_only_task(num=a) - - -input_only_workflow_lp = InputOnlyWorkflow.create_launch_plan() - - -@_tasks.dynamic_task -def lp_yield_input_only_wf(wf_params): - wf_params.logging.info("Running inner task... yielding a launchplan for input only workflow") - input_only_workflow_lp_execution = input_only_workflow_lp() - yield input_only_workflow_lp_execution - - -def test_dynamic_launch_plan_yielding_of_input_only_workflow(): - outputs = lp_yield_input_only_wf.unit_test() - # TODO: Currently, Flytekit will not return early and not do anything if there are any workflow nodes detected - # in the output of a dynamic task. - dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME] - - assert len(dj_spec.nodes) == 1 - assert len(dj_spec.outputs) == 0 - assert len(outputs.keys()) == 2 - - # Using the id of the launch plan node, and then appending /inputs.pb to the string, should give you in the outputs - # map the LiteralMap of the inputs of that node - input_key = "{}/inputs.pb".format(dj_spec.nodes[0].id) - lp_input_map = outputs[input_key] - assert lp_input_map.literals["a"] is not None diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index dec28f0f54..ab0fdec38a 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -4,7 +4,6 @@ import pytest from flyteidl.core.tasks_pb2 import TaskMetadata from google.protobuf import text_format -from k8s.io.api.core.v1 import generated_pb2 import flytekit.models.interface as interface_models import flytekit.models.literals as literal_models @@ -227,36 +226,6 @@ def test_container(resources): assert obj == task.Container.from_flyte_idl(obj.to_flyte_idl()) -def test_sidecar_task(): - pod_spec = generated_pb2.PodSpec() - container = generated_pb2.Container(name="containery") - pod_spec.containers.extend([container]) - obj = task.SidecarJob( - pod_spec=pod_spec, - primary_container_name="primary", - annotations={"a1": "a1"}, - labels={"b1": "b1"}, - ) - assert obj.primary_container_name == "primary" - assert len(obj.pod_spec.containers) == 1 - assert obj.pod_spec.containers[0].name == "containery" - assert obj.annotations["a1"] == "a1" - assert obj.labels["b1"] == "b1" - - obj2 = task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) - assert obj2 == obj - - -def test_sidecar_task_label_annotation_not_provided(): - pod_spec = generated_pb2.PodSpec() - obj = task.SidecarJob(pod_spec=pod_spec, primary_container_name="primary") - - assert obj.primary_container_name == "primary" - - obj2 = task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) - assert 
obj2 == obj - - def test_dataloadingconfig(): dlc = task.DataLoadingConfig( "s3://input/path", diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 5904d2efdf..cd80d166e2 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -3,8 +3,8 @@ import pytest from mock import MagicMock, patch -from flytekit.common.exceptions import user as user_exceptions from flytekit.configuration import internal +from flytekit.exceptions import user as user_exceptions from flytekit.models import common as common_models from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py index b26253e1db..f229489b14 100644 --- a/tests/flytekit/unit/remote/test_wrapper_classes.py +++ b/tests/flytekit/unit/remote/test_wrapper_classes.py @@ -3,7 +3,6 @@ import pytest -from flytekit.common.translator import gather_dependent_entities, get_serializable from flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.context_manager import Image, ImageConfig @@ -11,6 +10,7 @@ from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.remote import FlyteWorkflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/sdk/__init__.py b/tests/flytekit/unit/sdk/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/sdk/conftest.py b/tests/flytekit/unit/sdk/conftest.py deleted file mode 100644 index 874a288263..0000000000 --- a/tests/flytekit/unit/sdk/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytest as _pytest - -from flytekit.configuration import TemporaryConfiguration - - -@_pytest.fixture(scope="function", autouse=True) -def set_fake_config(): - with TemporaryConfiguration(None, internal_overrides={"image": "fakeimage"}): - yield diff --git a/tests/flytekit/unit/sdk/tasks/__init__.py b/tests/flytekit/unit/sdk/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py deleted file mode 100644 index 2de75de96c..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py +++ /dev/null @@ -1,82 +0,0 @@ -import mock -from k8s.io.api.core.v1 import generated_pb2 - -from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import sidecar_task as _sidecar_task -from flytekit.configuration.internal import IMAGE as _IMAGE -from flytekit.sdk.tasks import dynamic_sidecar_task, inputs, outputs, python_task -from flytekit.sdk.types import Types - - -def get_pod_spec(): - a_container = generated_pb2.Container(name="main") - a_container.command.extend(["foo", "bar"]) - a_container.volumeMounts.extend( - [ - generated_pb2.VolumeMount( - name="scratch", - mountPath="/scratch", - ) - ] - ) - - pod_spec = generated_pb2.PodSpec( - restartPolicy="Never", - ) - pod_spec.containers.extend([a_container, generated_pb2.Container(name="sidecar")]) - return pod_spec - - -with mock.patch.object(_IMAGE, 
"get", return_value="docker.io/blah:abc123"): - - @outputs(out1=Types.String) - @python_task - def simple_python_task(wf_params, out1): - out1.set("test") - - @inputs(in1=Types.Integer) - @outputs(out1=Types.String) - @dynamic_sidecar_task( - cpu_request="10", - memory_limit="2Gi", - environment={"foo": "bar"}, - pod_spec=get_pod_spec(), - primary_container_name="main", - ) - def simple_dynamic_sidecar_task(wf_params, in1, out1): - yield simple_python_task() - - -def test_dynamic_sidecar_task(): - assert isinstance(simple_dynamic_sidecar_task, _sdk_runnable.SdkRunnableTask) - assert isinstance(simple_dynamic_sidecar_task, _sidecar_task.SdkDynamicSidecarTask) - assert isinstance(simple_dynamic_sidecar_task, _sidecar_task.SdkSidecarTask) - assert isinstance(simple_dynamic_sidecar_task, _sdk_dynamic.SdkDynamicTaskMixin) - - pod_spec = simple_dynamic_sidecar_task.custom["podSpec"] - assert pod_spec["restartPolicy"] == "Never" - assert len(pod_spec["containers"]) == 2 - primary_container = pod_spec["containers"][0] - assert primary_container["name"] == "main" - assert primary_container["args"] == [ - "pyflyte-execute", - "--task-module", - "tests.flytekit.unit.sdk.tasks.test_dynamic_sidecar_tasks", - "--task-name", - "simple_dynamic_sidecar_task", - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - ] - assert primary_container["volumeMounts"] == [{"mountPath": "/scratch", "name": "scratch"}] - assert {"name": "foo", "value": "bar"} in primary_container["env"] - assert primary_container["resources"] == { - "requests": {"cpu": {"string": "10"}}, - "limits": {"memory": {"string": "2Gi"}}, - } - assert pod_spec["containers"][1]["name"] == "sidecar" - assert simple_dynamic_sidecar_task.custom["primaryContainerName"] == "main" diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py deleted file mode 100644 index 44ded2e3bd..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py +++ /dev/null @@ -1,241 +0,0 @@ -from six import moves as _six_moves - -from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow - - -@inputs(in1=Types.Integer) -@outputs(out_str=[Types.String], out_ints=[[Types.Integer]]) -@dynamic_task -def sample_batch_task(wf_params, in1, out_str, out_ints): - res = ["I'm the first result"] - for i in _six_moves.range(0, in1): - task = sub_task(in1=i) - yield task - res.append(task.outputs.out1) - res.append("I'm after each sub-task result") - res.append("I'm the last result") - - res2 = [] - for i in _six_moves.range(0, in1): - task = int_sub_task(in1=i) - yield task - res2.append(task.outputs.out1) - - # Nested batch tasks - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - out_str.set(res) - out_ints.set(res2) - - -@outputs(out_ints=[Types.Integer]) -@dynamic_task -def sample_batch_task_sq(wf_params, out_ints): - res2 = [] - for i in _six_moves.range(0, 3): - task = sq_sub_task(in1=i) - yield task - res2.append(task.outputs.out1) - out_ints.set(res2) - - -@outputs(out_str=[Types.String], out_ints=[[Types.Integer]]) -@dynamic_task -def sample_batch_task_no_inputs(wf_params, 
out_str, out_ints): - res = ["I'm the first result"] - for i in _six_moves.range(0, 3): - task = sub_task(in1=i) - yield task - res.append(task.outputs.out1) - res.append("I'm after each sub-task result") - res.append("I'm the last result") - - res2 = [] - for i in _six_moves.range(0, 3): - task = int_sub_task(in1=i) - yield task - res2.append(task.outputs.out1) - - # Nested batch tasks - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - task = sample_batch_task_sq() - yield task - res2.append(task.outputs.out_ints) - - out_str.set(res) - out_ints.set(res2) - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@python_task -def sub_task(wf_params, in1, out1): - out1.set("hello {}".format(in1)) - - -@inputs(in1=Types.Integer) -@outputs(out1=[Types.Integer]) -@python_task -def int_sub_task(wf_params, in1, out1): - wf_params.stats.incr("int_sub_task") - out1.set([in1, in1 * 2, in1 * 3]) - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.Integer) -@python_task -def sq_sub_task(wf_params, in1, out1): - out1.set(in1 * in1) - - -@inputs(in1=Types.Integer) -@outputs(out_str=[Types.String]) -@dynamic_task -def no_future_batch_task(wf_params, in1, out_str): - out_str.set(["res1", "res2"]) - - -def manual_assign_name(): - pass - - -@inputs(task_input_num=Types.Integer) -@outputs(out=Types.Integer) -@dynamic_task -def dynamic_wf_task(wf_params, task_input_num, out): - wf_params.logging.info("Running inner task... yielding a code generated sub workflow") - - input_a = Input(Types.Integer, help="Tell me something") - node1 = sq_sub_task(in1=input_a) - - MyUnregisteredWorkflow = workflow( - inputs={"a": input_a}, - outputs={"ooo": Output(node1.outputs.out1, sdk_type=Types.Integer, help="This is an integer output")}, - nodes={"node_one": node1}, - ) - - setattr(MyUnregisteredWorkflow, "auto_assign_name", manual_assign_name) - MyUnregisteredWorkflow._platform_valid_name = "unregistered" - - unregistered_workflow_execution = MyUnregisteredWorkflow(a=task_input_num) - out.set(unregistered_workflow_execution.outputs.ooo) - - -def test_batch_task(): - assert isinstance(sample_batch_task, _sdk_runnable.SdkRunnableTask) - assert isinstance(sample_batch_task, _sdk_dynamic.SdkDynamicTask) - assert isinstance(sample_batch_task, _sdk_dynamic.SdkDynamicTaskMixin) - - expected = { - "out_str": [ - "I'm the first result", - "hello 0", - "I'm after each sub-task result", - "hello 1", - "I'm after each sub-task result", - "hello 2", - "I'm after each sub-task result", - "I'm the last result", - ], - "out_ints": [[0, 0, 0], [1, 2, 3], [2, 4, 6], [0, 1, 4], [0, 1, 4]], - } - - res = sample_batch_task.unit_test(in1=3) - assert expected == res - - -def test_no_future_batch_task(): - expected = {"out_str": ["res1", "res2"]} - - res = no_future_batch_task.unit_test(in1=3) - assert expected == res - - -def test_dynamic_workflow(): - res = dynamic_wf_task.unit_test(task_input_num=2) - dynamic_spec = res["futures.pb"] - assert len(dynamic_spec.nodes) == 1 - assert len(dynamic_spec.subworkflows) == 1 - assert len(dynamic_spec.tasks) == 1 - - -@inputs(task_input_num=Types.Integer) -@outputs(out=Types.Integer) -@dynamic_task -def nested_dynamic_wf_task(wf_params, task_input_num, out): - wf_params.logging.info("Running inner task... 
yielding a code generated sub workflow") - - # Inner workflow - input_a = Input(Types.Integer, help="Tell me something") - node1 = sq_sub_task(in1=input_a) - - MyUnregisteredWorkflowInner = workflow( - inputs={"a": input_a}, - outputs={"ooo": Output(node1.outputs.out1, sdk_type=Types.Integer, help="This is an integer output")}, - nodes={"node_one": node1}, - ) - - setattr(MyUnregisteredWorkflowInner, "auto_assign_name", manual_assign_name) - MyUnregisteredWorkflowInner._platform_valid_name = "unregistered" - - # Output workflow - input_a = Input(Types.Integer, help="Tell me something") - node1 = MyUnregisteredWorkflowInner(a=task_input_num) - - MyUnregisteredWorkflowOuter = workflow( - inputs={"a": input_a}, - outputs={"ooo": Output(node1.outputs.ooo, sdk_type=Types.Integer, help="This is an integer output")}, - nodes={"node_one": node1}, - ) - - setattr(MyUnregisteredWorkflowOuter, "auto_assign_name", manual_assign_name) - MyUnregisteredWorkflowOuter._platform_valid_name = "unregistered" - - unregistered_workflow_execution = MyUnregisteredWorkflowOuter(a=task_input_num) - out.set(unregistered_workflow_execution.outputs.ooo) - - -def test_nested_dynamic_workflow(): - res = nested_dynamic_wf_task.unit_test(task_input_num=2) - dynamic_spec = res["futures.pb"] - assert len(dynamic_spec.nodes) == 1 - assert len(dynamic_spec.subworkflows) == 2 - assert len(dynamic_spec.tasks) == 1 - - -@inputs(task_input_num=Types.Integer) -@dynamic_task -def dynamic_wf_no_outputs_task(wf_params, task_input_num): - wf_params.logging.info("Running inner task... yielding a code generated sub workflow") - - input_a = Input(Types.Integer, help="Tell me something") - node1 = sq_sub_task(in1=input_a) - - MyUnregisteredWorkflow = workflow(inputs={"a": input_a}, outputs={}, nodes={"node_one": node1}) - - setattr(MyUnregisteredWorkflow, "auto_assign_name", manual_assign_name) - MyUnregisteredWorkflow._platform_valid_name = "unregistered" - - unregistered_workflow_execution = MyUnregisteredWorkflow(a=task_input_num) - yield unregistered_workflow_execution - - -def test_dynamic_workflow_no_outputs(): - res = dynamic_wf_no_outputs_task.unit_test(task_input_num=2) - dynamic_spec = res["futures.pb"] - assert len(dynamic_spec.nodes) == 1 - assert len(dynamic_spec.subworkflows) == 1 - assert len(dynamic_spec.tasks) == 1 diff --git a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py deleted file mode 100644 index e81edec5de..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py +++ /dev/null @@ -1,141 +0,0 @@ -import logging as _logging -from datetime import datetime as _datetime - -import six as _six - -from flytekit.common import utils as _common_utils -from flytekit.common.tasks import hive_task as _hive_task -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.common.types import containers as _containers -from flytekit.common.types import helpers as _type_helpers -from flytekit.common.types import schema as _schema -from flytekit.common.types.impl.schema import Schema -from flytekit.engines import common as _common_engine -from flytekit.models import literals as _literals -from flytekit.models.core.identifier import WorkflowExecutionIdentifier -from flytekit.sdk.tasks import hive_task, inputs, outputs, qubole_hive_task -from flytekit.sdk.types import Types - - -@hive_task(cache_version="1") -def 
sample_hive_task_no_input(wf_params): - return _six.text_type("select 5") - - -@inputs(in1=Types.Integer) -@hive_task(cache_version="1") -def sample_hive_task(wf_params, in1): - return _six.text_type("select ") + _six.text_type(in1) - - -@hive_task -def sample_hive_task_no_queries(wf_params): - return [] - - -@qubole_hive_task( - cache_version="1", - cluster_label=_six.text_type("cluster_label"), - tags=[], -) -def sample_qubole_hive_task_no_input(wf_params): - return _six.text_type("select 5") - - -@inputs(in1=Types.Integer) -@qubole_hive_task( - cache_version="1", - cluster_label=_six.text_type("cluster_label"), - tags=[_six.text_type("tag1")], -) -def sample_qubole_hive_task(wf_params, in1): - return _six.text_type("select ") + _six.text_type(in1) - - -def test_hive_task(): - assert isinstance(sample_hive_task, _sdk_runnable.SdkRunnableTask) - assert isinstance(sample_hive_task, _hive_task.SdkHiveTask) - - sample_hive_task.unit_test(in1=5) - - -@outputs(hive_results=[Types.Schema()]) -@qubole_hive_task -def two_queries(wf_params, hive_results): - q1 = "SELECT 1" - q2 = "SELECT 'two'" - schema_1, formatted_query_1 = Schema.create_from_hive_query(select_query=q1) - schema_2, formatted_query_2 = Schema.create_from_hive_query(select_query=q2) - - hive_results.set([schema_1, schema_2]) - return [formatted_query_1, formatted_query_2] - - -def test_interface_setup(): - outs = two_queries.interface.outputs - assert outs["hive_results"].type.collection_type is not None - assert outs["hive_results"].type.collection_type.schema is not None - assert outs["hive_results"].type.collection_type.schema.columns == [] - - -def test_sdk_output_references_construction(): - references = { - name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(two_queries.interface.outputs) - } - # Before user code is run, the outputs passed to the user code should not have values - assert references["hive_results"].sdk_value == _base_sdk_types.Void() - - # Should be a list of schemas - assert isinstance(references["hive_results"].sdk_type, _containers.TypedCollectionType) - assert isinstance(references["hive_results"].sdk_type.sub_type, _schema.SchemaInstantiator) - - -def test_hive_task_query_generation(): - with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory: - context = _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), - execution_date=_datetime.utcnow(), - stats=None, # TODO: A mock stats object that we can read later. - logging=_logging, # TODO: A mock logging object that we can read later. 
- tmp_dir=user_working_directory, - ) - references = { - name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) - for name, variable in _six.iteritems(two_queries.interface.outputs) - } - - qubole_hive_jobs = two_queries._generate_plugin_objects(context, references) - assert len(qubole_hive_jobs) == 2 - - # deprecated, collection is only here for backwards compatibility - assert len(qubole_hive_jobs[0].query_collection.queries) == 1 - assert len(qubole_hive_jobs[1].query_collection.queries) == 1 - - # The output references should now have the same fake S3 path as the formatted queries - assert references["hive_results"].value[0].uri != "" - assert references["hive_results"].value[1].uri != "" - assert references["hive_results"].value[0].uri in qubole_hive_jobs[0].query.query - assert references["hive_results"].value[1].uri in qubole_hive_jobs[1].query.query - - -def test_hive_task_dynamic_job_spec_generation(): - with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory: - context = _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), - execution_date=_datetime.utcnow(), - stats=None, # TODO: A mock stats object that we can read later. - logging=_logging, # TODO: A mock logging object that we can read later. - tmp_dir=user_working_directory, - ) - dj_spec = two_queries._produce_dynamic_job_spec(context, _literals.LiteralMap(literals={})) - - # Bindings - assert len(dj_spec.outputs[0].binding.collection.bindings) == 2 - assert isinstance(dj_spec.outputs[0].binding.collection.bindings[0].scalar.schema, Schema) - assert isinstance(dj_spec.outputs[0].binding.collection.bindings[1].scalar.schema, Schema) - - # Custom field is filled in - assert len(dj_spec.tasks[0].custom) > 0 diff --git a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py deleted file mode 100644 index 10d1eb4fec..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py +++ /dev/null @@ -1,84 +0,0 @@ -import mock -from k8s.io.api.core.v1 import generated_pb2 - -from flytekit.common.tasks import sidecar_task as _sidecar_task -from flytekit.common.tasks import task as _sdk_task -from flytekit.configuration.internal import IMAGE as _IMAGE -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, sidecar_task -from flytekit.sdk.types import Types - - -def get_pod_spec(): - a_container = generated_pb2.Container( - name="a container", - ) - a_container.command.extend(["fee", "fi", "fo", "fum"]) - a_container.volumeMounts.extend( - [ - generated_pb2.VolumeMount( - name="volume mount", - mountPath="some/where", - ) - ] - ) - - pod_spec = generated_pb2.PodSpec( - restartPolicy="OnFailure", - ) - pod_spec.containers.extend([a_container, generated_pb2.Container(name="another container")]) - return pod_spec - - -with mock.patch.object(_IMAGE, "get", return_value="docker.io/blah:abc123"): - - @inputs(in1=Types.Integer) - @outputs(out1=Types.String) - @sidecar_task( - cpu_request="10", - gpu_limit="2", - environment={"foo": "bar"}, - pod_spec=get_pod_spec(), - primary_container_name="a container", - annotations={"a": "a"}, - labels={"b": "b"}, - ) - def simple_sidecar_task(wf_params, in1, out1): - pass - - -simple_sidecar_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - - -def test_sidecar_task(): - assert 
isinstance(simple_sidecar_task, _sdk_task.SdkTask) - assert isinstance(simple_sidecar_task, _sidecar_task.SdkSidecarTask) - - pod_spec = simple_sidecar_task.custom["podSpec"] - assert pod_spec["restartPolicy"] == "OnFailure" - assert len(pod_spec["containers"]) == 2 - primary_container = pod_spec["containers"][0] - assert primary_container["name"] == "a container" - assert primary_container["args"] == [ - "pyflyte-execute", - "--task-module", - "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", - "--task-name", - "simple_sidecar_task", - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - ] - assert primary_container["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}] - assert {"name": "foo", "value": "bar"} in primary_container["env"] - assert primary_container["resources"] == { - "requests": {"cpu": {"string": "10"}}, - "limits": {"gpu": {"string": "2"}}, - } - assert pod_spec["containers"][1]["name"] == "another container" - assert simple_sidecar_task.custom["primaryContainerName"] == "a container" - assert simple_sidecar_task.custom["annotations"]["a"] == "a" - assert simple_sidecar_task.custom["labels"]["b"] == "b" diff --git a/tests/flytekit/unit/sdk/tasks/test_spark_task.py b/tests/flytekit/unit/sdk/tasks/test_spark_task.py deleted file mode 100644 index cfccd5ecde..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_spark_task.py +++ /dev/null @@ -1,78 +0,0 @@ -import datetime as _datetime -import os as _os -import sys as _sys - -from flytekit.bin import entrypoint as _entrypoint -from flytekit.common import constants as _common_constants -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import spark_task as _spark_task -from flytekit.models import types as _type_models -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, spark_task -from flytekit.sdk.types import Types - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@spark_task(spark_conf={"A": "B"}, hadoop_conf={"C": "D"}) -def default_task(wf_params, sc, in1, out1): - out1.set("hello") - - -default_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - - -def test_default_python_task(): - assert isinstance(default_task, _spark_task.SdkSparkTask) - assert isinstance(default_task, _sdk_runnable.SdkRunnableTask) - assert default_task.interface.inputs["in1"].description == "" - assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) - assert default_task.interface.outputs["out1"].description == "" - assert default_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING - ) - assert default_task.type == _common_constants.SdkTaskType.SPARK_TASK - assert default_task.task_function_name == "default_task" - assert default_task.task_module == __name__ - assert default_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert default_task.metadata.deprecated_error_message == "" - assert default_task.metadata.discoverable is False - assert default_task.metadata.discovery_version == "" - assert default_task.metadata.retries.retries == 0 - assert len(default_task.container.resources.limits) == 0 - assert len(default_task.container.resources.requests) == 0 - assert default_task.custom["sparkConf"]["A"] == "B" - assert default_task.custom["hadoopConf"]["C"] == "D" - 
assert default_task.hadoop_conf["C"] == "D" - assert default_task.spark_conf["A"] == "B" - assert _os.path.abspath(_entrypoint.__file__)[:-1] in default_task.custom["mainApplicationFile"] - assert default_task.custom["executorPath"] == _sys.executable - - pb2 = default_task.to_flyte_idl() - assert pb2.custom["sparkConf"]["A"] == "B" - assert pb2.custom["hadoopConf"]["C"] == "D" - - -def test_overrides_spark_task(): - assert default_task.id.name == "name" - new_task = default_task.with_overrides(new_spark_conf={"x": "1"}, new_hadoop_conf={"y": "2"}) - assert isinstance(new_task, _spark_task.SdkSparkTask) - assert new_task.id.name.startswith("name-") - assert new_task.custom["sparkConf"]["x"] == "1" - assert new_task.custom["hadoopConf"]["y"] == "2" - - assert default_task.custom["sparkConf"]["A"] == "B" - assert default_task.custom["hadoopConf"]["C"] == "D" - - assert default_task.has_valid_name is False - default_task.assign_name("my-task") - assert default_task.has_valid_name - assert new_task.interface == default_task.interface - - assert default_task.__hash__() != new_task.__hash__() - - new_task2 = default_task.with_overrides(new_spark_conf={"x": "1"}, new_hadoop_conf={"y": "2"}) - assert new_task2.id.name == new_task.id.name - - t = new_task(in1=1) - assert t.outputs["out1"] is not None diff --git a/tests/flytekit/unit/sdk/tasks/test_tasks.py b/tests/flytekit/unit/sdk/tasks/test_tasks.py deleted file mode 100644 index 33e0287199..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_tasks.py +++ /dev/null @@ -1,108 +0,0 @@ -import datetime as _datetime -import os as _os - -from flytekit import configuration as _configuration -from flytekit.common import constants as _common_constants -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models -from flytekit.models import types as _type_models -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@python_task -def default_task(wf_params, in1, out1): - pass - - -default_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - - -def test_default_python_task(): - assert isinstance(default_task, _sdk_runnable.SdkRunnableTask) - assert default_task.interface.inputs["in1"].description == "" - assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) - assert default_task.interface.outputs["out1"].description == "" - assert default_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING - ) - assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK - assert default_task.task_function_name == "default_task" - assert default_task.task_module == __name__ - assert default_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert default_task.metadata.deprecated_error_message == "" - assert default_task.metadata.discoverable is False - assert default_task.metadata.discovery_version == "" - assert default_task.metadata.retries.retries == 0 - assert len(default_task.container.resources.limits) == 0 - assert len(default_task.container.resources.requests) == 0 - - -def test_default_resources(): - with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../configuration/configs/good.config", - ) - ): - 
- @inputs(in1=Types.Integer) - @outputs(out1=Types.String) - @python_task() - def default_task2(wf_params, in1, out1): - pass - - request_map = {r.name: r.value for r in default_task2.container.resources.requests} - - limit_map = {l.name: l.value for l in default_task2.container.resources.limits} - - assert request_map[_task_models.Resources.ResourceName.CPU] == "500m" - assert request_map[_task_models.Resources.ResourceName.MEMORY] == "500Gi" - assert request_map[_task_models.Resources.ResourceName.GPU] == "1" - assert request_map[_task_models.Resources.ResourceName.STORAGE] == "500Gi" - - assert limit_map[_task_models.Resources.ResourceName.CPU] == "501m" - assert limit_map[_task_models.Resources.ResourceName.MEMORY] == "501Gi" - assert limit_map[_task_models.Resources.ResourceName.GPU] == "2" - assert limit_map[_task_models.Resources.ResourceName.STORAGE] == "501Gi" - - -def test_overriden_resources(): - with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../configuration/configs/good.config", - ) - ): - - @inputs(in1=Types.Integer) - @outputs(out1=Types.String) - @python_task( - memory_limit="100Gi", - memory_request="50Gi", - cpu_limit="1000m", - cpu_request="500m", - gpu_limit="1", - gpu_request="0", - storage_request="100Gi", - storage_limit="200Gi", - ) - def default_task2(wf_params, in1, out1): - pass - - request_map = {r.name: r.value for r in default_task2.container.resources.requests} - - limit_map = {l.name: l.value for l in default_task2.container.resources.limits} - - assert request_map[_task_models.Resources.ResourceName.CPU] == "500m" - assert request_map[_task_models.Resources.ResourceName.MEMORY] == "50Gi" - assert request_map[_task_models.Resources.ResourceName.GPU] == "0" - assert request_map[_task_models.Resources.ResourceName.STORAGE] == "100Gi" - - assert limit_map[_task_models.Resources.ResourceName.CPU] == "1000m" - assert limit_map[_task_models.Resources.ResourceName.MEMORY] == "100Gi" - assert limit_map[_task_models.Resources.ResourceName.GPU] == "1" - assert limit_map[_task_models.Resources.ResourceName.STORAGE] == "200Gi" diff --git a/tests/flytekit/unit/sdk/test_workflow.py b/tests/flytekit/unit/sdk/test_workflow.py deleted file mode 100644 index 2b604b5bf6..0000000000 --- a/tests/flytekit/unit/sdk/test_workflow.py +++ /dev/null @@ -1,150 +0,0 @@ -import pytest - -from flytekit.common import constants -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types, containers, primitives -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, Output, workflow, workflow_class - - -def test_input(): - i = Input(primitives.Integer, help="blah", default=None) - assert i.name == "" - assert i.sdk_default is None - assert i.default == base_sdk_types.Void() - assert i.sdk_required is False - assert i.required is None - assert i.help == "blah" - assert i.var.description == "blah" - assert i.sdk_type == primitives.Integer - - i = i.rename_and_return_reference("new_name") - assert i.name == "new_name" - assert i.sdk_default is None - assert i.default == base_sdk_types.Void() - assert i.sdk_required is False - assert i.required is None - assert i.help == "blah" - assert i.var.description == "blah" - assert i.sdk_type == primitives.Integer - - i = Input(primitives.Integer, default=1) - assert i.name == "" - assert i.sdk_default == 1 - assert i.default == 
primitives.Integer(1) - assert i.sdk_required is False - assert i.required is None - assert i.help is None - assert i.var.description == "" - assert i.sdk_type == primitives.Integer - - i = i.rename_and_return_reference("new_name") - assert i.name == "new_name" - assert i.sdk_default == 1 - assert i.default == primitives.Integer(1) - assert i.sdk_required is False - assert i.required is None - assert i.help is None - assert i.var.description == "" - assert i.sdk_type == primitives.Integer - - with pytest.raises(_user_exceptions.FlyteAssertion): - Input(primitives.Integer, required=True, default=1) - - i = Input([primitives.Integer], default=[1, 2]) - assert i.name == "" - assert i.sdk_default == [1, 2] - assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)]) - assert i.sdk_required is False - assert i.required is None - assert i.help is None - assert i.var.description == "" - assert i.sdk_type == containers.List(primitives.Integer) - - i = i.rename_and_return_reference("new_name") - assert i.name == "new_name" - assert i.sdk_default == [1, 2] - assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)]) - assert i.sdk_required is False - assert i.required is None - assert i.help is None - assert i.var.description == "" - assert i.sdk_type == containers.List(primitives.Integer) - - -def test_output(): - o = Output(1, sdk_type=primitives.Integer, help="blah") - assert o.name == "" - assert o.var.description == "blah" - assert o.var.type == primitives.Integer.to_flyte_literal_type() - assert o.binding_data.scalar.primitive.integer == 1 - - o = o.rename_and_return_reference("new_name") - assert o.name == "new_name" - assert o.var.description == "blah" - assert o.var.type == primitives.Integer.to_flyte_literal_type() - assert o.binding_data.scalar.primitive.integer == 1 - - -def _get_node_by_id(wf, nid): - for n in wf.nodes: - if n.id == nid: - return n - assert False - - -def test_workflow_no_node_dependencies_or_outputs(): - @inputs(a=Types.Integer) - @outputs(b=Types.Integer) - @python_task - def my_task(wf_params, a, b): - b.set(a + 1) - - i1 = Input(Types.Integer) - i2 = Input(Types.Integer, default=5, help="Not required.") - - input_dict = {"input_1": i1, "input_2": i2} - - nodes = { - "a": my_task(a=input_dict["input_1"]), - "b": my_task(a=input_dict["input_2"]), - "c": my_task(a=100), - } - - w = workflow(inputs=input_dict, outputs={}, nodes=nodes) - - assert w.interface.inputs["input_1"].type == Types.Integer.to_flyte_literal_type() - assert w.interface.inputs["input_2"].type == Types.Integer.to_flyte_literal_type() - assert _get_node_by_id(w, "a").inputs[0].var == "a" - assert _get_node_by_id(w, "a").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert _get_node_by_id(w, "a").inputs[0].binding.promise.var == "input_1" - assert _get_node_by_id(w, "b").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert _get_node_by_id(w, "b").inputs[0].binding.promise.var == "input_2" - assert _get_node_by_id(w, "c").inputs[0].binding.scalar.primitive.integer == 100 - - -def test_workflow_metaclass_no_node_dependencies_or_outputs(): - @inputs(a=Types.Integer) - @outputs(b=Types.Integer) - @python_task - def my_task(wf_params, a, b): - b.set(a + 1) - - @workflow_class - class sup(object): - input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help="Not required.") - - a = my_task(a=input_1) - b = my_task(a=input_2) - c = my_task(a=100) - - 
assert sup.interface.inputs["input_1"].type == Types.Integer.to_flyte_literal_type() - assert sup.interface.inputs["input_2"].type == Types.Integer.to_flyte_literal_type() - assert _get_node_by_id(sup, "a").inputs[0].var == "a" - assert _get_node_by_id(sup, "a").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert _get_node_by_id(sup, "a").inputs[0].binding.promise.var == "input_1" - assert _get_node_by_id(sup, "b").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID - assert _get_node_by_id(sup, "b").inputs[0].binding.promise.var == "input_2" - assert _get_node_by_id(sup, "c").inputs[0].binding.scalar.primitive.integer == 100 diff --git a/tests/flytekit/unit/sdk/types/__init__.py b/tests/flytekit/unit/sdk/types/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/sdk/types/test_blobs.py b/tests/flytekit/unit/sdk/types/test_blobs.py deleted file mode 100644 index 78e99fe5a7..0000000000 --- a/tests/flytekit/unit/sdk/types/test_blobs.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.sdk import types as _sdk_types - - -@pytest.mark.parametrize( - "blob_tuple", - [ - (_sdk_types.Types.Blob, _blob_impl.Blob), - (_sdk_types.Types.CSV, _blob_impl.Blob), - (_sdk_types.Types.MultiPartBlob, _blob_impl.MultiPartBlob), - (_sdk_types.Types.MultiPartCSV, _blob_impl.MultiPartBlob), - ], -) -def test_instantiable_blobs(blob_tuple): - sdk_type, impl = blob_tuple - - blob_inst = sdk_type() - blob_type_inst = sdk_type(blob_inst) - assert isinstance(blob_inst, impl) - assert isinstance(blob_type_inst, sdk_type) - - with pytest.raises(Exception): - sdk_type(1, 2) - - with pytest.raises(Exception): - sdk_type(a=1) - - blob_inst = sdk_type.create_at_known_location("abc") - assert isinstance(blob_inst, impl) diff --git a/tests/flytekit/unit/sdk/types/test_primitives.py b/tests/flytekit/unit/sdk/types/test_primitives.py deleted file mode 100644 index e039b9780b..0000000000 --- a/tests/flytekit/unit/sdk/types/test_primitives.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from flytekit.sdk import types as _sdk_types - - -def test_integer(): - with pytest.raises(Exception): - _sdk_types.Types.Integer() - - -def test_float(): - with pytest.raises(Exception): - _sdk_types.Types.Float() - - -def test_string(): - with pytest.raises(Exception): - _sdk_types.Types.String() - - -def test_bool(): - with pytest.raises(Exception): - _sdk_types.Types.Boolean() - - -def test_datetime(): - with pytest.raises(Exception): - _sdk_types.Types.Datetime() - - -def test_timedelta(): - with pytest.raises(Exception): - _sdk_types.Types.Timedelta() diff --git a/tests/flytekit/unit/sdk/types/test_schema.py b/tests/flytekit/unit/sdk/types/test_schema.py deleted file mode 100644 index 95f70d6702..0000000000 --- a/tests/flytekit/unit/sdk/types/test_schema.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.types import Types - - -def test_generic_schema(): - @inputs(a=Types.Schema()) - @outputs(b=Types.Schema()) - @python_task - def fake_task(wf_params, a, b): - pass - - -def test_typed_schema(): - @inputs(a=Types.Schema([("a", Types.Integer), ("b", Types.Integer)])) - @outputs(b=Types.Schema([("a", Types.Integer), ("b", Types.Integer)])) - @python_task - def fake_task(wf_params, a, b): - pass - - -def test_bad_definition(): - with 
pytest.raises(_user_exceptions.FlyteValueException): - Types.Schema([]) - - -def test_bad_column_types(): - with pytest.raises(_user_exceptions.FlyteTypeException): - Types.Schema([("a", Types.Blob)]) - with pytest.raises(_user_exceptions.FlyteTypeException): - Types.Schema([("a", Types.MultiPartBlob)]) - with pytest.raises(_user_exceptions.FlyteTypeException): - Types.Schema([("a", Types.MultiPartCSV)]) - with pytest.raises(_user_exceptions.FlyteTypeException): - Types.Schema([("a", Types.CSV)]) - with pytest.raises(_user_exceptions.FlyteTypeException): - Types.Schema([("a", Types.Schema())]) - - -def test_create_from_hive_query(): - s, q = Types.Schema().create_from_hive_query("SELECT * FROM table", known_location="s3://somewhere/") - - assert s.mode == "wb" - assert s.local_path is None - assert s.remote_location == "s3://somewhere/" - assert "SELECT * FROM table" in q - assert s.remote_location in q diff --git a/tests/flytekit/unit/tasks/__init__.py b/tests/flytekit/unit/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/test_plugins.py b/tests/flytekit/unit/test_plugins.py deleted file mode 100644 index c6680e34da..0000000000 --- a/tests/flytekit/unit/test_plugins.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -from flytekit import plugins -from flytekit.tools import lazy_loader - - -@pytest.mark.run(order=0) -def test_spark_plugin(): - plugins.pyspark.SparkContext - import pyspark - - assert plugins.pyspark.SparkContext == pyspark.SparkContext - - -@pytest.mark.run(order=1) -def test_schema_plugin(): - plugins.numpy.dtype - plugins.pandas.DataFrame - import numpy - import pandas - - assert plugins.numpy.dtype == numpy.dtype - assert pandas.DataFrame == pandas.DataFrame - - -@pytest.mark.run(order=2) -def test_sidecar_plugin(): - assert isinstance(plugins.k8s.io.api.core.v1.generated_pb2, lazy_loader._LazyLoadModule) - assert isinstance( - plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, - lazy_loader._LazyLoadModule, - ) - import k8s.io.api.core.v1.generated_pb2 - import k8s.io.apimachinery.pkg.api.resource.generated_pb2 - - k8s.io.api.core.v1.generated_pb2.Container - k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity - - -@pytest.mark.run(order=2) -def test_hive_sensor_plugin(): - assert isinstance(plugins.hmsclient, lazy_loader._LazyLoadModule) - assert isinstance(plugins.hmsclient.genthrift.hive_metastore.ttypes, lazy_loader._LazyLoadModule) - import hmsclient - import hmsclient.genthrift.hive_metastore.ttypes - - hmsclient.HMSClient - hmsclient.genthrift.hive_metastore.ttypes.NoSuchObjectException diff --git a/tests/flytekit/unit/common_tests/test_translator.py b/tests/flytekit/unit/test_translator.py similarity index 99% rename from tests/flytekit/unit/common_tests/test_translator.py rename to tests/flytekit/unit/test_translator.py index 91b1dd2780..e21353c38e 100644 --- a/tests/flytekit/unit/common_tests/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -2,7 +2,6 @@ from collections import OrderedDict from flytekit import ContainerTask, Resources -from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import kwtypes from flytekit.core.context_manager import FastSerializationSettings, Image, ImageConfig @@ -11,6 +10,7 @@ from flytekit.core.task import ReferenceTask, task from flytekit.core.workflow import ReferenceWorkflow, workflow from flytekit.models.core import identifier as identifier_models +from 
flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/tools/test_aws.py b/tests/flytekit/unit/tools/test_aws.py deleted file mode 100644 index 93445b0752..0000000000 --- a/tests/flytekit/unit/tools/test_aws.py +++ /dev/null @@ -1,7 +0,0 @@ -from flytekit.interfaces.data.s3.s3proxy import AwsS3Proxy - - -def test_aws_s3_splitting(): - (bucket, key) = AwsS3Proxy._split_s3_path_to_bucket_and_key("s3://bucket/some/key") - assert bucket == "bucket" - assert key == "some/key" diff --git a/tests/flytekit/unit/tools/test_lazy_loader.py b/tests/flytekit/unit/tools/test_lazy_loader.py deleted file mode 100644 index 7801318408..0000000000 --- a/tests/flytekit/unit/tools/test_lazy_loader.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest -import six - -from flytekit.tools import lazy_loader - - -def test_lazy_loader_error_message(): - lazy_mod = lazy_loader.lazy_load_module("made.up.module") - lazy_loader.LazyLoadPlugin("uninstalled_plugin", [], [lazy_mod]) - with pytest.raises(ImportError) as e: - lazy_mod.some_bad_attr - - assert "uninstalled_plugin" in six.text_type(e.value) - assert "flytekit[all]" in six.text_type(e.value) diff --git a/tests/flytekit/unit/tools/test_module_loader.py b/tests/flytekit/unit/tools/test_module_loader.py index aa7fdd255c..9a568f0260 100644 --- a/tests/flytekit/unit/tools/test_module_loader.py +++ b/tests/flytekit/unit/tools/test_module_loader.py @@ -1,12 +1,12 @@ import os import sys -from flytekit.common import utils as _utils +from flytekit.core import utils from flytekit.tools import module_loader def test_module_loading(): - with _utils.AutoDeletingTempDir("mypackage") as pkg: + with utils.AutoDeletingTempDir("mypackage") as pkg: path = pkg.name # Create directories top_level = os.path.join(path, "top") diff --git a/tests/flytekit/unit/type_engines/__init__.py b/tests/flytekit/unit/type_engines/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/type_engines/default/__init__.py b/tests/flytekit/unit/type_engines/default/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py deleted file mode 100644 index 8157739456..0000000000 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest -from flyteidl.core import errors_pb2 as _errors_pb2 - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import proto as _proto -from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models -from flytekit.type_engines.default import flyte as _flyte_engine - - -def test_proto_from_literal_type(): - sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, - metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"}, - ) - ) - - assert sdk_type.pb_type == _errors_pb2.ContainerError - - -def test_generic_proto_from_literal_type(): - sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.STRUCT, - metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"}, - ) - ) - - 
assert sdk_type.pb_type == _errors_pb2.ContainerError - - -def test_unloadable_module_from_literal_type(): - with pytest.raises(_user_exceptions.FlyteAssertion): - _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, - metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2_no_exist.ContainerError"}, - ) - ) - - -def test_unloadable_proto_from_literal_type(): - with pytest.raises(_user_exceptions.FlyteAssertion): - _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, - metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerErrorNoExist"}, - ) - ) - - -def test_infer_proto_from_literal(): - sdk_type = _flyte_engine.FlyteDefaultTypeEngine().infer_sdk_type_from_literal( - _literal_models.Literal( - scalar=_literal_models.Scalar( - binary=_literal_models.Binary( - value="", - tag="{}{}".format( - _proto.Protobuf.TAG_PREFIX, - "flyteidl.core.errors_pb2.ContainerError", - ), - ) - ) - ) - ) - assert sdk_type.pb_type == _errors_pb2.ContainerError diff --git a/tests/flytekit/unit/use_scenarios/__init__.py b/tests/flytekit/unit/use_scenarios/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/__init__.py b/tests/flytekit/unit/use_scenarios/unit_testing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py deleted file mode 100644 index 0530bdb865..0000000000 --- a/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py +++ /dev/null @@ -1,165 +0,0 @@ -from flytekit.common.utils import AutoDeletingTempDir -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.test_utils import flyte_test -from flytekit.sdk.types import Types - - -@flyte_test -def test_create_blob_from_local_path(): - @outputs(a=Types.Blob) - @python_task - def test_create_from_local_path(wf_params, a): - with AutoDeletingTempDir("t") as tmp: - tmp_name = tmp.get_named_tempfile("abc.blob") - with open(tmp_name, "wb") as w: - w.write("Hello world".encode("utf-8")) - a.set(tmp_name) - - out = test_create_from_local_path.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert r.read().decode("utf-8") == "Hello world" - - -@flyte_test -def test_write_blob(): - @outputs(a=Types.Blob) - @python_task - def test_write(wf_params, a): - b = Types.Blob() - with b as w: - w.write("Hello world".encode("utf-8")) - a.set(b) - - out = test_write.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert r.read().decode("utf-8") == "Hello world" - - -@flyte_test -def test_blob_passing(): - @inputs(a=Types.Blob) - @outputs(b=Types.Blob) - @python_task - def test_pass(wf_params, a, b): - b.set(a) - - b = Types.Blob() - with b as w: - w.write("Hello world".encode("utf-8")) - - out = test_pass.unit_test(a=b) - assert len(out) == 1 - with out["b"] as r: - assert r.read().decode("utf-8") == "Hello world" - - out = test_pass.unit_test(a=out["b"]) - assert len(out) == 1 - with out["b"] as r: - assert r.read().decode("utf-8") == "Hello world" - - -@flyte_test -def test_create_multipartblob_from_local_path(): - @outputs(a=Types.MultiPartBlob) - @python_task - def test_create_from_local_path(wf_params, a): - with AutoDeletingTempDir("t") as tmp: - with open(tmp.get_named_tempfile("0"), "wb") as 
w: - w.write("Hello world".encode("utf-8")) - with open(tmp.get_named_tempfile("1"), "wb") as w: - w.write("Hello world2".encode("utf-8")) - a.set(tmp.name) - - out = test_create_from_local_path.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert len(r) == 2 - assert r[0].read().decode("utf-8") == "Hello world" - assert r[1].read().decode("utf-8") == "Hello world2" - - -@flyte_test -def test_write_multipartblob(): - @outputs(a=Types.MultiPartBlob) - @python_task - def test_write(wf_params, a): - b = Types.MultiPartBlob() - with b.create_part("0") as w: - w.write("Hello world".encode("utf-8")) - with b.create_part("1") as w: - w.write("Hello world2".encode("utf-8")) - a.set(b) - - out = test_write.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert len(r) == 2 - assert r[0].read().decode("utf-8") == "Hello world" - assert r[1].read().decode("utf-8") == "Hello world2" - - -@flyte_test -def test_multipartblob_passing(): - @inputs(a=Types.MultiPartBlob) - @outputs(b=Types.MultiPartBlob) - @python_task - def test_pass(wf_params, a, b): - b.set(a) - - b = Types.MultiPartBlob() - with b.create_part("0") as w: - w.write("Hello world".encode("utf-8")) - with b.create_part("1") as w: - w.write("Hello world2".encode("utf-8")) - - out = test_pass.unit_test(a=b) - assert len(out) == 1 - with out["b"] as r: - assert len(r) == 2 - assert r[0].read().decode("utf-8") == "Hello world" - assert r[1].read().decode("utf-8") == "Hello world2" - - out = test_pass.unit_test(a=out["b"]) - assert len(out) == 1 - with out["b"] as r: - assert len(r) == 2 - assert r[0].read().decode("utf-8") == "Hello world" - assert r[1].read().decode("utf-8") == "Hello world2" - - -@flyte_test -def test_write_csv(): - @outputs(a=Types.CSV) - @python_task - def test_write(wf_params, a): - b = Types.CSV() - with b as w: - w.write("Hello,world,hi") - a.set(b) - - out = test_write.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert r.read() == "Hello,world,hi" - - -@flyte_test -def test_write_multipartcsv(): - @outputs(a=Types.MultiPartCSV) - @python_task - def test_write(wf_params, a): - b = Types.MultiPartCSV() - with b.create_part("0") as w: - w.write("Hello,world,1") - with b.create_part("1") as w: - w.write("Hello,world,2") - a.set(b) - - out = test_write.unit_test() - assert len(out) == 1 - with out["a"] as r: - assert len(r) == 2 - assert r[0].read() == "Hello,world,1" - assert r[1].read() == "Hello,world,2" diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py deleted file mode 100644 index a137f254c1..0000000000 --- a/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from flytekit.sdk.tasks import hive_task - - -def test_no_queries(): - @hive_task - def test_hive_task(wf_params): - pass - - assert test_hive_task.unit_test() == [] - - -def test_empty_list_queries(): - @hive_task - def test_hive_task(wf_params): - return [] - - assert test_hive_task.unit_test() == [] - - -def test_one_query(): - @hive_task - def test_hive_task(wf_params): - return "abc" - - assert test_hive_task.unit_test() == ["abc"] - - -def test_multiple_queries(): - @hive_task - def test_hive_task(wf_params): - return ["abc", "cde"] - - assert test_hive_task.unit_test() == ["abc", "cde"] - - -def test_raise_exception(): - @hive_task - def test_hive_task(wf_params): - raise FloatingPointError("Floating point error for some reason.") - - with 
pytest.raises(FloatingPointError): - test_hive_task.unit_test() diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py deleted file mode 100644 index e7df18e86a..0000000000 --- a/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py +++ /dev/null @@ -1,139 +0,0 @@ -import pandas as pd -import pytest - -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.sdk.tasks import inputs, outputs, python_task -from flytekit.sdk.test_utils import flyte_test -from flytekit.sdk.types import Types - - -@flyte_test -def test_generic_schema(): - @inputs(a=Types.Schema()) - @outputs(b=Types.Schema()) - @python_task - def copy_task(wf_params, a, b): - out = Types.Schema()() - with a as r: - with out as w: - for df in r.iter_chunks(): - w.write(df) - b.set(out) - - # Test generic copy and pass through - a = Types.Schema()() - with a as w: - w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})) - w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]})) - - outs = copy_task.unit_test(a=a) - - with outs["b"] as r: - df = r.read() - assert list(df["a"]) == [1, 2, 3] - assert list(df["b"]) == [4.0, 5.0, 6.0] - - df = r.read() - assert list(df["a"]) == [3, 2, 1] - assert list(df["b"]) == [6.0, 5.0, 4.0] - - assert r.read() is None - - # Test typed copy and pass through - a = Types.Schema([("a", Types.Integer), ("b", Types.Float)])() - with a as w: - w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})) - w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]})) - - outs = copy_task.unit_test(a=a) - - with outs["b"] as r: - df = r.read() - assert list(df["a"]) == [1, 2, 3] - assert list(df["b"]) == [4.0, 5.0, 6.0] - - df = r.read() - assert list(df["a"]) == [3, 2, 1] - assert list(df["b"]) == [6.0, 5.0, 4.0] - - assert r.read() is None - - -@flyte_test -def test_typed_schema(): - @inputs(a=Types.Schema([("a", Types.Integer), ("b", Types.Float)])) - @outputs(b=Types.Schema([("a", Types.Integer), ("b", Types.Float)])) - @python_task - def copy_task(wf_params, a, b): - out = Types.Schema([("a", Types.Integer), ("b", Types.Float)])() - with a as r: - with out as w: - for df in r.iter_chunks(): - w.write(df) - b.set(out) - - # Test typed copy and pass through - a = Types.Schema([("a", Types.Integer), ("b", Types.Float)])() - with a as w: - w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})) - w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]})) - - outs = copy_task.unit_test(a=a) - - with outs["b"] as r: - df = r.read() - assert list(df["a"]) == [1, 2, 3] - assert list(df["b"]) == [4.0, 5.0, 6.0] - - df = r.read() - assert list(df["a"]) == [3, 2, 1] - assert list(df["b"]) == [6.0, 5.0, 4.0] - - assert r.read() is None - - # Test untyped failure - a = Types.Schema()() - with a as w: - w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})) - w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]})) - - with pytest.raises(_user_exceptions.FlyteTypeException): - copy_task.unit_test(a=a) - - -@flyte_test -def test_subset_of_columns(): - @outputs(a=Types.Schema([("a", Types.Integer), ("b", Types.String)])) - @python_task() - def source(wf_params, a): - out = Types.Schema([("a", Types.Integer), ("b", Types.String)])() - with out as writer: - writer.write(pd.DataFrame.from_dict({"a": [1, 2, 3, 4, 5], "b": ["a", "b", "c", "d", "e"]})) - a.set(out) - - 
@inputs(a=Types.Schema([("a", Types.Integer)])) - @python_task() - def sink(wf_params, a): - with a as reader: - df = reader.read(concat=True) - assert len(df.columns.values) == 1 - assert df["a"].tolist() == [1, 2, 3, 4, 5] - - with a as reader: - df = reader.read(truncate_extra_columns=False) - assert df.columns.values.tolist() == ["a", "b"] - assert df["a"].tolist() == [1, 2, 3, 4, 5] - assert df["b"].tolist() == ["a", "b", "c", "d", "e"] - - o = source.unit_test() - sink.unit_test(**o) - - -@flyte_test -def test_no_output_set(): - @outputs(a=Types.Schema()) - @python_task() - def null_set(wf_params, a): - pass - - assert null_set.unit_test()["a"] is None
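
The files removed above carried the last unit coverage for the legacy flytekit.sdk surface: the @inputs/@outputs decorator stacks, the Types.* wrapper types, and the .unit_test() harness. Equivalent coverage against the surviving API uses @task/@workflow with native Python annotations and plain local invocation. The sketch below is illustrative only (the names sq, wf, and test_local_execution are hypothetical, not tests added by this patch):

from flytekit import task, workflow

@task
def sq(n: int) -> int:
    # Native annotations replace the @inputs(...=Types.Integer) / @outputs(...) stack.
    return n * n

@workflow
def wf(n: int = 3) -> int:
    # Tasks are invoked with keyword arguments inside a workflow body.
    return sq(n=n)

def test_local_execution():
    # Direct invocation replaces the old .unit_test() helper: tasks and
    # workflows execute locally as ordinary Python callables.
    assert sq(n=5) == 25
    assert wf(n=3) == 9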
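
The per-flag resource arguments exercised by the deleted test_tasks.py (cpu_request=..., memory_limit=..., gpu_limit=..., storage_request=..., and so on) map onto the Resources dataclass in the current API. A minimal sketch mirroring the deleted test's values (the task name heavy is hypothetical):

from flytekit import Resources, task

@task(
    requests=Resources(cpu="500m", mem="50Gi", gpu="0", storage="100Gi"),
    limits=Resources(cpu="1000m", mem="100Gi", gpu="1", storage="200Gi"),
)
def heavy(n: int) -> int:
    # requests/limits land in the task's container resource spec at serialization time.
    return n + 1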
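
Note also that get_serializable, used by the relocated test_translator.py, now resolves from flytekit.tools.translator instead of flytekit.common.translator. A sketch of the call pattern, assuming the Image/SerializationSettings construction shown at the top of that renamed file and the hypothetical sq task from the first sketch above:

from collections import OrderedDict

from flytekit.core import context_manager
from flytekit.core.context_manager import Image, ImageConfig
from flytekit.tools.translator import get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
settings = context_manager.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

# Translates the Python task object into its serialized task-spec representation.
task_spec = get_serializable(OrderedDict(), settings, sq)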