diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index d38b666aee..f5322003d7 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -6,6 +6,9 @@ on: - master pull_request: +env: + FLYTE_SDK_LOGGING_LEVEL: 10 # debug + jobs: build: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 73c81d2a6a..d2d337772f 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ _build/ docs/source/generated/ .pytest_flyte htmlcov +*.ipynb diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 8089e8bd74..6d0b9c2bf6 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -155,12 +155,15 @@ import sys +import flytekit.configuration.internal + if sys.version_info < (3, 10): from importlib_metadata import entry_points else: from importlib.metadata import entry_points -from flytekit.configuration.sdk import USE_STRUCTURED_DATASET +from flytekit import configuration +from flytekit.configuration import internal as _internal from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes from flytekit.core.checkpointer import Checkpoint @@ -190,7 +193,7 @@ from flytekit.models.types import LiteralType from flytekit.types import directory, file, schema -if USE_STRUCTURED_DATASET.get(): +if _internal.LocalSDK.USE_STRUCTURED_DATASET.read(): from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 5123e90225..ca9f4eb3b5 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,7 +1,8 @@ import contextlib import datetime as _datetime -import os as _os +import os import pathlib +import tempfile import traceback as _traceback from typing import List, Optional @@ -9,21 +10,13 @@ from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit import PythonFunctionTask -from flytekit.configuration 
import TemporaryConfiguration as _TemporaryConfiguration -from flytekit.configuration import internal as _internal_config -from flytekit.configuration import sdk as _sdk_config +from flytekit.configuration import SerializationSettings, StatsConfig +from flytekit.core import SERIALIZED_CONTEXT_ENV_VAR from flytekit.core import constants as _constants from flytekit.core import utils from flytekit.core.base_task import IgnoreOutputs, PythonTask from flytekit.core.checkpointer import SyncCheckpoint -from flytekit.core.context_manager import ( - ExecutionParameters, - ExecutionState, - FlyteContext, - FlyteContextManager, - SerializationSettings, - get_image_config, -) +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.map_task import MapPythonTask from flytekit.core.promise import VoidPromise @@ -56,10 +49,10 @@ def _compute_array_job_index(): :rtype: int """ offset = 0 - if _os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"): - offset = int(_os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET")) - if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"): - return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) + if os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"): + offset = int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET")) + if os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"): + return offset + int(os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) return offset @@ -82,7 +75,7 @@ def _dispatch_execute( logger.debug(f"Starting _dispatch_execute for {task_def.name}") try: # Step1 - local_inputs_file = _os.path.join(ctx.execution_state.working_dir, "inputs.pb") + local_inputs_file = os.path.join(ctx.execution_state.working_dir, "inputs.pb") ctx.file_access.get_data(inputs_path, local_inputs_file) input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) 
idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto) @@ -157,23 +150,46 @@ def _dispatch_execute( logger.error("!! End Error Captured by Flyte !!") for k, v in output_file_dict.items(): - utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(ctx.execution_state.engine_dir, k)) + utils.write_proto_to_file(v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k)) ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True) logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") logger.debug("Finished _dispatch_execute") +def get_one_of(*args) -> str: + """ + Helper function to iterate through a series of different environment variables. This function exists because + some settings reference multiple environment variables for legacy reasons. + :param args: List of environment variables to look for. + :return: The first defined value in the environment, or an empty string if nothing is found. 
+ """ + for k in args: + if k in os.environ: + return os.environ[k] + return "" + + @contextlib.contextmanager def setup_execution( raw_output_data_prefix: str, checkpoint_path: Optional[str] = None, prev_checkpoint: Optional[str] = None, - dynamic_addl_distro: Optional[str] = None, - dynamic_dest_dir: Optional[str] = None, ): - ctx = FlyteContextManager.current_context() + exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ") + exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM") + exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_NAME", "_F_NM") + exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF") + exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP") + + tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ") + tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM") + tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM") + tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V") + compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "") + + ctx = FlyteContextManager.current_context() # Create directories user_workspace_dir = ctx.file_access.get_random_local_directory() logger.info(f"Using user directory {user_workspace_dir}") @@ -187,80 +203,56 @@ def setup_execution( execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( - project=_internal_config.EXECUTION_PROJECT.get(), - domain=_internal_config.EXECUTION_DOMAIN.get(), - name=_internal_config.EXECUTION_NAME.get(), + project=exe_project, + domain=exe_domain, + name=exe_name, ), execution_date=_datetime.datetime.utcnow(), stats=_get_stats( + cfg=StatsConfig.auto(), # Stats metric path will be: # registration_project.registration_domain.app.module.task_name.user_stats # and it will be tagged with execution-level values for project/domain/wf/lp - "{}.{}.{}.user_stats".format( - _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(), - 
_internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(), - _internal_config.TASK_NAME.get() or _internal_config.NAME.get(), - ), + prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats", tags={ - "exec_project": _internal_config.EXECUTION_PROJECT.get(), - "exec_domain": _internal_config.EXECUTION_DOMAIN.get(), - "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(), - "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(), + "exec_project": exe_project, + "exec_domain": exe_domain, + "exec_workflow": exe_wf, + "exec_launchplan": exe_lp, "api_version": _api_version, }, ), logging=entrypoint_logger, tmp_dir=user_workspace_dir, - raw_output_prefix=ctx.file_access._raw_output_prefix, + raw_output_prefix=raw_output_data_prefix, checkpoint=checkpointer, ) - # TODO: Remove this check for flytekit 1.0 - if raw_output_data_prefix: - try: - file_access = FileAccessProvider( - local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), - raw_output_prefix=raw_output_data_prefix, - ) - except TypeError: # would be thrown from DataPersistencePlugins.find_plugin - logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}") - raise - else: - raise Exception("No raw output prefix detected. Please upgrade your version of Propeller to 0.4.0 or later.") - - with FlyteContextManager.with_context(ctx.with_file_access(file_access)) as ctx: - # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing. 
- env = { - _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(), - _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(), - } - - serialization_settings = SerializationSettings( - project=_internal_config.TASK_PROJECT.get(), - domain=_internal_config.TASK_DOMAIN.get(), - version=_internal_config.TASK_VERSION.get(), - image_config=get_image_config(), - env=env, + try: + file_access = FileAccessProvider( + local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), + raw_output_prefix=raw_output_data_prefix, ) + except TypeError: # would be thrown from DataPersistencePlugins.find_plugin + logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}") + raise - # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin, - # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task - # identifier for t1. - with FlyteContextManager.with_context(ctx.with_serialization_settings(serialization_settings)) as ctx: - # Because execution states do not look up the context chain, it has to be made last - with FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.new_execution_state().with_params( - mode=ExecutionState.Mode.TASK_EXECUTION, - user_space_params=execution_parameters, - additional_context={ - "dynamic_addl_distro": dynamic_addl_distro, - "dynamic_dest_dir": dynamic_dest_dir, - }, - ) - ) - ) as ctx: - yield ctx + es = ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + user_space_params=execution_parameters, + ) + cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es) + + if compressed_serialization_settings: + ss = SerializationSettings.from_transport(compressed_serialization_settings) + ssb = ss.new_builder() + ssb.project = exe_project + ssb.domain = exe_domain + ssb.version = tk_version + cb = cb.with_serialization_settings(ssb.build()) + + with 
FlyteContextManager.with_context(cb) as ctx: + yield ctx def _handle_annotated_task( @@ -285,8 +277,6 @@ def _execute_task( resolver_args: List[str], checkpoint_path: Optional[str] = None, prev_checkpoint: Optional[str] = None, - dynamic_addl_distro: Optional[str] = None, - dynamic_dest_dir: Optional[str] = None, ): """ This function should be called for new API tasks (those only available in 0.16 and later that leverage Python @@ -314,23 +304,16 @@ def _execute_task( if len(resolver_args) < 1: raise Exception("cannot be <1") - with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - with setup_execution( - raw_output_data_prefix, - checkpoint_path=checkpoint_path, - prev_checkpoint=prev_checkpoint, - dynamic_addl_distro=dynamic_addl_distro, - dynamic_dest_dir=dynamic_dest_dir, - ) as ctx: - resolver_obj = load_object_from_module(resolver) - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if test: - logger.info( - f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" - ) - return - _handle_annotated_task(ctx, _task_def, inputs, output_prefix) + with setup_execution(raw_output_data_prefix, checkpoint_path, prev_checkpoint) as ctx: + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + if test: + logger.info( + f"Test detected, returning. 
Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" + ) + return + _handle_annotated_task(ctx, _task_def, inputs, output_prefix) @_scopes.system_entry_point @@ -344,8 +327,6 @@ def _execute_map_task( resolver_args: List[str], checkpoint_path: Optional[str] = None, prev_checkpoint: Optional[str] = None, - dynamic_addl_distro: Optional[str] = None, - dynamic_dest_dir: Optional[str] = None, ): """ This function should be called by map task and aws-batch task @@ -362,38 +343,31 @@ def _execute_map_task( :param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be nested). :param resolver_args: Args that will be passed to the aforementioned resolver's load_task function - :param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the - compressed code archive has been uploaded. - :param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed - code archives should be installed in the flyte task container. :return: """ if len(resolver_args) < 1: raise Exception(f"Resolver args cannot be <1, got {resolver_args}") - with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - with setup_execution( - raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir - ) as ctx: - resolver_obj = load_object_from_module(resolver) - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if not isinstance(_task_def, PythonFunctionTask): - raise Exception("Map tasks cannot be run with instance tasks.") - map_task = MapPythonTask(_task_def, max_concurrency) - - task_index = _compute_array_job_index() - output_prefix = _os.path.join(output_prefix, str(task_index)) - - if test: - logger.info( - f"Test detected, returning. 
Inputs: {inputs} Computed task index: {task_index} " - f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} " - f"Resolver and args: {resolver} {resolver_args}" - ) - return + with setup_execution(raw_output_data_prefix, checkpoint_path, prev_checkpoint) as ctx: + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + if not isinstance(_task_def, PythonFunctionTask): + raise Exception("Map tasks cannot be run with instance tasks.") + map_task = MapPythonTask(_task_def, max_concurrency) + + task_index = _compute_array_job_index() + output_prefix = os.path.join(output_prefix, str(task_index)) + + if test: + logger.info( + f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} " + f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} " + f"Resolver and args: {resolver} {resolver_args}" + ) + return - _handle_annotated_task(ctx, map_task, inputs, output_prefix) + _handle_annotated_task(ctx, map_task, inputs, output_prefix) def normalize_inputs( @@ -424,8 +398,6 @@ def _pass_through(): @_click.option("--checkpoint-path", required=False) @_click.option("--prev-checkpoint", required=False) @_click.option("--test", is_flag=True) -@_click.option("--dynamic-addl-distro", required=False) -@_click.option("--dynamic-dest-dir", required=False) @_click.option("--resolver", required=False) @_click.argument( "resolver-args", @@ -439,8 +411,6 @@ def execute_task_cmd( test, prev_checkpoint, checkpoint_path, - dynamic_addl_distro, - dynamic_dest_dir, resolver, resolver_args, ): @@ -464,8 +434,6 @@ def execute_task_cmd( test=test, resolver=resolver, resolver_args=resolver_args, - dynamic_addl_distro=dynamic_addl_distro, - dynamic_dest_dir=dynamic_dest_dir, checkpoint_path=checkpoint_path, prev_checkpoint=prev_checkpoint, ) @@ -475,31 +443,19 @@ def execute_task_cmd( 
@_click.option("--additional-distribution", required=False) @_click.option("--dest-dir", required=False) @_click.argument("task-execute-cmd", nargs=-1, type=_click.UNPROCESSED) -def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): +def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_execute_cmd: List[str]): """ Downloads a compressed code distribution specified by additional-distribution and then calls the underlying task execute command for the updated code. - :param Text additional_distribution: - :param Text dest_dir: - :param task_execute_cmd: - :return: """ if additional_distribution is not None: if not dest_dir: - dest_dir = _os.getcwd() + dest_dir = os.getcwd() _download_distribution(additional_distribution, dest_dir) # Use the commandline to run the task execute command rather than calling it directly in python code # since the current runtime bytecode references the older user code, rather than the downloaded distribution. - - # Insert the call to fast before the unbounded resolver args - cmd = [] - for arg in task_execute_cmd: - if arg == "--resolver": - cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir]) - cmd.append(arg) - - _os.system(" ".join(cmd)) + os.system(" ".join(task_execute_cmd)) @_pass_through.command("pyflyte-map-execute") @@ -508,8 +464,6 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): @_click.option("--raw-output-data-prefix", required=False) @_click.option("--max-concurrency", type=int, required=False) @_click.option("--test", is_flag=True) -@_click.option("--dynamic-addl-distro", required=False) -@_click.option("--dynamic-dest-dir", required=False) @_click.option("--resolver", required=True) @_click.option("--checkpoint-path", required=False) @_click.option("--prev-checkpoint", required=False) @@ -524,8 +478,6 @@ def map_execute_task_cmd( raw_output_data_prefix, max_concurrency, test, - dynamic_addl_distro, - 
dynamic_dest_dir, resolver, resolver_args, prev_checkpoint, @@ -543,8 +495,6 @@ def map_execute_task_cmd( raw_output_data_prefix=raw_output_data_prefix, max_concurrency=max_concurrency, test=test, - dynamic_addl_distro=dynamic_addl_distro, - dynamic_dest_dir=dynamic_dest_dir, resolver=resolver, resolver_args=resolver_args, checkpoint_path=checkpoint_path, diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 9754375dec..94fb5f9ea3 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -5,22 +5,15 @@ import time from typing import Optional +import grpc import requests as _requests from flyteidl.service import admin_pb2_grpc as _admin_service from flyteidl.service import auth_pb2 from flyteidl.service import auth_pb2_grpc as auth_service from google.protobuf.json_format import MessageToJson as _MessageToJson -from grpc import RpcError as _RpcError -from grpc import StatusCode as _GrpcStatusCode -from grpc import insecure_channel as _insecure_channel -from grpc import secure_channel as _secure_channel -from grpc import ssl_channel_credentials as _ssl_channel_credentials from flytekit.clis.auth import credentials as _credentials_access -from flytekit.configuration import creds as creds_config -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID -from flytekit.configuration.creds import COMMAND as _COMMAND +from flytekit.configuration import AuthType, PlatformConfig from flytekit.exceptions import user as _user_exceptions from flytekit.exceptions.user import FlyteAuthenticationException from flytekit.loggers import cli_logger @@ -28,119 +21,11 @@ _utf_8 = "utf-8" -def _refresh_credentials_standard(flyte_client: RawSynchronousFlyteClient): - """ - This function is used when the configuration value for AUTH_MODE is set to 'standard'. 
- This either fetches the existing access token or initiates the flow to request a valid access token and store it. - :param flyte_client: RawSynchronousFlyteClient - :return: - """ - authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None - if not flyte_client.oauth2_metadata or not flyte_client.public_client_config: - raise ValueError( - "Raw Flyte client attempting client credentials flow but no response from Admin detected. " - "Check your Admin server's .well-known endpoints to make sure they're working as expected." - ) - client = _credentials_access.get_client( - redirect_endpoint=flyte_client.public_client_config.redirect_uri, - client_id=flyte_client.public_client_config.client_id, - scopes=flyte_client.public_client_config.scopes, - auth_endpoint=flyte_client.oauth2_metadata.authorization_endpoint, - token_endpoint=flyte_client.oauth2_metadata.token_endpoint, - ) - if client.has_valid_credentials and not flyte_client.check_access_token(client.credentials.access_token): - # When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient - # will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's - # metadata field yet. Therefore, if there's a mismatch, copy it over. - flyte_client.set_access_token(client.credentials.access_token, authorization_header_key) - # However, after copying over credentials from the AuthorizationClient, we have to clear it to avoid the - # scenario where the stored credentials in the keyring are expired. If that's the case, then we only try - # them once (because client here is a singleton), and the next time, we'll do one of the two other conditions - # below. 
- client.clear() - return - elif client.can_refresh_token: - client.refresh_access_token() - else: - client.start_authorization_flow() - - flyte_client.set_access_token(client.credentials.access_token, authorization_header_key) - - -def _refresh_credentials_basic(flyte_client: RawSynchronousFlyteClient): - """ - This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler - is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, - like when waiting for another workflow to complete from within a task). This function uses basic auth, which means - the credentials for basic auth must be present from wherever this code is running. - - :param flyte_client: RawSynchronousFlyteClient - :return: - """ - if not flyte_client.oauth2_metadata or not flyte_client.public_client_config: - raise ValueError( - "Raw Flyte client attempting client credentials flow but no response from Admin detected. " - "Check your Admin server's .well-known endpoints to make sure they're working as expected." - ) - - token_endpoint = flyte_client.oauth2_metadata.token_endpoint - scopes = creds_config.SCOPES.get() or flyte_client.public_client_config.scopes - scopes = ",".join(scopes) - - # Note that unlike the Pkce flow, the client ID does not come from Admin. 
- client_secret = get_secret() - cli_logger.debug("Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), scopes)) - authorization_header = get_basic_authorization_header(_CLIENT_ID.get(), client_secret) - token, expires_in = get_token(token_endpoint, authorization_header, scopes) - cli_logger.info("Retrieved new token, expires in {}".format(expires_in)) - authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None - flyte_client.set_access_token(token, authorization_header_key) - - -def _refresh_credentials_from_command(flyte_client): - """ - This function is used when the configuration value for AUTH_MODE is set to 'external_process'. - It reads an id token generated by an external process started by running the 'command'. - - :param flyte_client: RawSynchronousFlyteClient - :return: - """ - - command = _COMMAND.get() - cli_logger.debug("Starting external process to generate id token. Command {}".format(command)) - try: - output = subprocess.run(command, capture_output=True, text=True, check=True) - except subprocess.CalledProcessError as e: - cli_logger.error("Failed to generate token from command {}".format(command)) - raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) - flyte_client.set_access_token(output.stdout.strip()) - - -def _refresh_credentials_noop(flyte_client): - pass - - -def _get_refresh_handler(auth_mode): - if auth_mode == "standard": - return _refresh_credentials_standard - elif auth_mode == "basic" or auth_mode == "client_credentials": - return _refresh_credentials_basic - elif auth_mode == "external_process": - return _refresh_credentials_from_command - else: - raise ValueError( - "Invalid auth mode [{}] specified. 
Please update the creds config to use a valid value".format(auth_mode) - ) - - def _handle_rpc_error(retry=False): def decorator(fn): def handler(*args, **kwargs): """ Wraps rpc errors as Flyte exceptions and handles authentication the client. - :param args: - :param kwargs: - :return: """ max_retries = 3 max_wait_time = 1000 @@ -148,23 +33,23 @@ def handler(*args, **kwargs): for i in range(max_retries): try: return fn(*args, **kwargs) - except _RpcError as e: - if e.code() == _GrpcStatusCode.UNAUTHENTICATED: + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAUTHENTICATED: # Always retry auth errors. if i == (max_retries - 1): # Exit the loop and wrap the authentication error. raise _user_exceptions.FlyteAuthenticationException(str(e)) cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n") - refresh_handler_fn = _get_refresh_handler(creds_config.AUTH_MODE.get()) - refresh_handler_fn(args[0]) - # There are two cases that we should throw error immediately - # 1. Entity already exists when we register entity - # 2. Entity not found when we fetch entity - elif e.code() == _GrpcStatusCode.ALREADY_EXISTS: + args[0].refresh_credentials() + elif e.code() == grpc.StatusCode.ALREADY_EXISTS: + # There are two cases that we should throw error immediately + # 1. Entity already exists when we register entity + # 2. Entity not found when we fetch entity raise _user_exceptions.FlyteEntityAlreadyExistsException(e) - elif e.code() == _GrpcStatusCode.NOT_FOUND: + elif e.code() == grpc.StatusCode.NOT_FOUND: raise _user_exceptions.FlyteEntityNotExistException(e) else: + print(e) # No more retries if retry=False or max_retries reached. 
if (retry is False) or i == (max_retries - 1): raise @@ -183,8 +68,8 @@ def _handle_invalid_create_request(fn): def handler(self, create_request): try: fn(self, create_request) - except _RpcError as e: - if e.code() == _GrpcStatusCode.INVALID_ARGUMENT: + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.INVALID_ARGUMENT: cli_logger.error("Error creating Flyte entity because of invalid arguments. Create request: ") cli_logger.error(_MessageToJson(create_request)) @@ -202,53 +87,60 @@ class RawSynchronousFlyteClient(object): be explicit as opposed to inferred from the environment or a configuration file. """ - def __init__(self, url, insecure=False, credentials=None, options=None, root_cert_file=None): + def __init__(self, cfg: PlatformConfig, **kwargs): """ Initializes a gRPC channel to the given Flyte Admin service. - :param Text url: The URL (including port if necessary) to connect to the appropriate Flyte Admin Service. - :param bool insecure: [Optional] Whether to use an insecure connection, default False - :param Text credentials: [Optional] If provided, a secure channel will be opened with the Flyte Admin Service. - :param dict[Text, Text] options: [Optional] A dict of key-value string pairs for configuring the gRPC core - runtime. - :param root_cert_file: Path to a local certificate file if you want. + Args: + cfg: Platform configuration; cfg.endpoint is the server address and cfg.insecure selects an insecure channel. 
+ kwargs: Extra gRPC channel arguments (e.g. credentials, options, compression). """ - self._channel = None - self._url = url - - if insecure: - self._channel = _insecure_channel(url, options=list((options or {}).items())) + self._cfg = cfg + if cfg.insecure: + self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs) else: - if root_cert_file: - with open(root_cert_file, "rb") as fh: - cert_bytes = fh.read() - channel_creds = _ssl_channel_credentials(root_certificates=cert_bytes) + if "credentials" not in kwargs: + credentials = grpc.ssl_channel_credentials( + root_certificates=kwargs.get("root_certificates", None), + private_key=kwargs.get("private_key", None), + certificate_chain=kwargs.get("certificate_chain", None), + ) else: - channel_creds = _ssl_channel_credentials() - - self._channel = _secure_channel( - url, - credentials or channel_creds, - options=list((options or {}).items()), + credentials = kwargs["credentials"] + self._channel = grpc.secure_channel( + target=cfg.endpoint, + credentials=credentials, + options=kwargs.get("options", None), + compression=kwargs.get("compression", None), ) self._stub = _admin_service.AdminServiceStub(self._channel) self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) try: resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) self._public_client_config = resp - except _RpcError: + except grpc.RpcError: cli_logger.debug("No public client auth config found, skipping.") self._public_client_config = None try: resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest()) self._oauth2_metadata = resp - except _RpcError: + except grpc.RpcError: cli_logger.debug("No OAuth2 Metadata found, skipping.") self._oauth2_metadata = None + cli_logger.info( + f"Flyte Client configured -> {self._cfg.endpoint} in {'insecure' if self._cfg.insecure else 'secure'} mode." + ) # metadata will hold the value of the token to send to the various endpoints. 
self._metadata = None + @classmethod + def with_root_certificate(cls, cfg: PlatformConfig, root_cert_file: str) -> RawSynchronousFlyteClient: + b = None + with open(root_cert_file, "rb") as fp: + b = fp.read() + return RawSynchronousFlyteClient(cfg, credentials=grpc.ssl_channel_credentials(root_certificates=b)) + @property def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]: return self._public_client_config @@ -259,7 +151,112 @@ def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]: @property def url(self) -> str: - return self._url + return self._cfg.endpoint + + def _refresh_credentials_standard(self): + """ + This function is used when the configuration value for AUTH_MODE is set to 'standard'. + This either fetches the existing access token or initiates the flow to request a valid access token and store it. + :param self: RawSynchronousFlyteClient + :return: + """ + authorization_header_key = self.public_client_config.authorization_metadata_key or None + if not self.oauth2_metadata or not self.public_client_config: + raise ValueError( + "Raw Flyte client attempting client credentials flow but no response from Admin detected. " + "Check your Admin server's .well-known endpoints to make sure they're working as expected." + ) + client = _credentials_access.get_client( + redirect_endpoint=self.public_client_config.redirect_uri, + client_id=self.public_client_config.client_id, + scopes=self.public_client_config.scopes, + auth_endpoint=self.oauth2_metadata.authorization_endpoint, + token_endpoint=self.oauth2_metadata.token_endpoint, + ) + if client.has_valid_credentials and not self.check_access_token(client.credentials.access_token): + # When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient + # will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's + # metadata field yet. Therefore, if there's a mismatch, copy it over. 
+ self.set_access_token(client.credentials.access_token, authorization_header_key) + # However, after copying over credentials from the AuthorizationClient, we have to clear it to avoid the + # scenario where the stored credentials in the keyring are expired. If that's the case, then we only try + # them once (because client here is a singleton), and the next time, we'll do one of the two other conditions + # below. + client.clear() + return + elif client.can_refresh_token: + client.refresh_access_token() + else: + client.start_authorization_flow() + + self.set_access_token(client.credentials.access_token, authorization_header_key) + + def _refresh_credentials_basic(self): + """ + This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler + is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, + like when waiting for another workflow to complete from within a task). This function uses basic auth, which means + the credentials for basic auth must be present from wherever this code is running. + + :param self: RawSynchronousFlyteClient + :return: + """ + if not self.oauth2_metadata or not self.public_client_config: + raise ValueError( + "Raw Flyte client attempting client credentials flow but no response from Admin detected. " + "Check your Admin server's .well-known endpoints to make sure they're working as expected." + ) + + token_endpoint = self.oauth2_metadata.token_endpoint + scopes = self._cfg.scopes or self.public_client_config.scopes + scopes = ",".join(scopes) + + # Note that unlike the Pkce flow, the client ID does not come from Admin. 
+ client_secret = self._cfg.client_credentials_secret + if not client_secret: + raise FlyteAuthenticationException("No client credentials secret provided in the config") + cli_logger.debug(f"Basic authorization flow with client id {self._cfg.client_id} scope {scopes}") + authorization_header = get_basic_authorization_header(self._cfg.client_id, client_secret) + token, expires_in = get_token(token_endpoint, authorization_header, scopes) + cli_logger.info("Retrieved new token, expires in {}".format(expires_in)) + authorization_header_key = self.public_client_config.authorization_metadata_key or None + self.set_access_token(token, authorization_header_key) + + def _refresh_credentials_from_command(self): + """ + This function is used when the configuration value for AUTH_MODE is set to 'external_process'. + It reads an id token generated by an external process started by running the 'command'. + + :param self: RawSynchronousFlyteClient + :return: + """ + + command = self._cfg.command + if not command: + raise FlyteAuthenticationException("No command specified in configuration for command authentication") + cli_logger.debug("Starting external process to generate id token. 
Command {}".format(command)) + try: + output = subprocess.run(command, capture_output=True, text=True, check=True) + except subprocess.CalledProcessError as e: + cli_logger.error("Failed to generate token from command {}".format(command)) + raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) + self.set_access_token(output.stdout.strip()) + + def _refresh_credentials_noop(self): + pass + + def refresh_credentials(self): + if self._cfg.auth_mode == AuthType.STANDARD: + return self._refresh_credentials_standard() + elif self._cfg.auth_mode == AuthType.BASIC or self._cfg.auth_mode == AuthType.CLIENT_CREDENTIALS: + return self._refresh_credentials_basic() + elif self._cfg.auth_mode == AuthType.EXTERNAL_PROCESS: + return self._refresh_credentials_from_command() + else: + raise ValueError( + f"Invalid auth mode [{self._cfg.auth_mode}] specified." + f"Please update the creds config to use a valid value" + ) def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"): # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses @@ -829,18 +826,6 @@ def get_token(token_endpoint, authorization_header, scope): return response["access_token"], response["expires_in"] -def get_secret(): - """ - This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION - config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. - :rtype: Text - """ - secret = _CREDENTIALS_SECRET.get() - if secret: - return secret - raise FlyteAuthenticationException("No secret could be found") - - def get_basic_authorization_header(client_id, client_secret): """ This function transforms the client id and the client secret into a header that conforms with http basic auth. 
diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 6415e92b15..d4408265ac 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -18,12 +18,9 @@ from google.protobuf.json_format import MessageToJson from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType -from flytekit import __version__ +from flytekit import __version__, configuration from flytekit.clients import friendly as _friendly_client from flytekit.clis.helpers import hydrate_registration_parameters -from flytekit.configuration import auth as _auth_config -from flytekit.configuration import platform as _platform_config -from flytekit.configuration import set_flyte_config_file from flytekit.core import utils from flytekit.core.context_manager import FlyteContextManager from flytekit.exceptions import user as _user_exceptions @@ -92,13 +89,14 @@ def _detect_default_config_file(): config_file = _get_config_file_path() if _get_user_filepath_home() and _os.path.exists(config_file): _click.secho("Using default config file at {}".format(_tt(config_file)), fg="blue") - set_flyte_config_file(config_file_path=config_file) + return config_file else: _click.secho( """Config file not found at default location, relying on environment variables instead. 
To setup your config file run 'flyte-cli setup-config'""", fg="blue", ) + return None def _get_io_string(literal_map, verbose=False): @@ -272,6 +270,13 @@ def _render_schedule_expr(lp): return "{:30}".format(sched_expr) +def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteClient: + parent_ctx = _click.get_current_context(silent=True) + cfg = parent_ctx.obj["config"] + cfg = cfg.with_parameters(endpoint=host, insecure=insecure) + return _friendly_client.SynchronousFlyteClient(cfg, root_certificates=parent_ctx.obj["cacert"]) + + _PROJECT_FLAGS = ["-p", "--project"] _DOMAIN_FLAGS = ["-d", "--domain"] _NAME_FLAGS = ["-n", "--name"] @@ -488,67 +493,27 @@ def make_context(self, cmd_name, args, parent=None): # Pass the parameters to the subcommand "setup-config", so both of the below commands can work. # flyte-cli -h localhost:30081 -i setup-config # flyte-cli setup-config -h localhost:30081 -i + ctx = super(_FlyteSubCommand, self).make_context(cmd_name, prefix_args + args, parent=parent) + ctx.obj = ctx.obj or {} + ctx.obj["cacert"] = parent.params["cacert"] or None if cmd_name == "setup-config": - ctx = super(_FlyteSubCommand, self).make_context(cmd_name, prefix_args + args, parent=parent) - ctx.obj = ctx.obj or {} - ctx.obj["cacert"] = parent.params["cacert"] or None return ctx config = parent.params["config"] if config is None: # Run this as the module is loading to pick up settings that click can # then use when constructing the commands - _detect_default_config_file() + config = _detect_default_config_file() - else: - config = parent.params["config"] - if _os.path.exists(config): - _click.secho("Using config file at {}".format(_tt(config)), fg="blue") - set_flyte_config_file(config_file_path=config) - else: - _click.secho( - "Config file not found at {}".format(_tt(config)), - fg="blue", - ) - - # These two flags are special in that they are specifiable in both the user's default ~/.flyte/config file, - # and in the flyte-cli command 
itself, both in the parent-command position (flyte-cli) , and in the - # child-command position (e.g. list-task-names). To get around this, first we read the value of the config - # object, and store it. Later in the file below are options for each of these options, one for the parent - # command, and one for the child command. If not set by the parent, and also not set by the child, - # then the value from the config file is used. - # - # For both host and insecure, command line values will override the setting in ~/.flyte/config file. - # - # The host url option is a required setting, so if missing it will fail, but it may be set in the click command, - # so we don't have to check now. It will be checked later. - _HOST_URL = None - try: - _HOST_URL = _platform_config.URL.get() - except _user_exceptions.FlyteAssertion: - pass - _INSECURE_FLAG = _platform_config.INSECURE.get() - - # This is where we handle the value read from the flyte-cli config file, if any, for the insecure flag. - # Previously we tried putting it into the default into the declaration of the option itself, but in click, it - # appears that flags operate like toggles. If both the default is true and the flag is passed in the command, - # they negate each other and it's as if it's not passed. Here we rectify that. 
- if _INSECURE_FLAG and _INSECURE_FLAGS[0] not in prefix_args: - prefix_args.append(_INSECURE_FLAGS[0]) - - # Use host url in config file if users don't specify the host url - if _HOST_FLAGS[0] not in prefix_args: - prefix_args.extend([_HOST_FLAGS[0], str(_HOST_URL)]) - ctx = super(_FlyteSubCommand, self).make_context(cmd_name, prefix_args + args, parent=parent) - ctx.obj = ctx.obj or {} - ctx.obj["cacert"] = parent.params["cacert"] or None + print("Config loading") + ctx.obj["config"] = configuration.PlatformConfig.auto(config_file=config) return ctx @_click.option( *_CONFIG_FLAGS, required=False, - type=str, + type=_click.Path(exists=True), default=None, help="[Optional] The filepath to the config file to pass to the sub-command (if applicable)." " If set again in the sub-command, the sub-command's parameter takes precedence.", @@ -650,8 +615,7 @@ def list_task_names(project, domain, host, insecure, token, limit, show_all, sor a specific project and domain. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Task Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -693,8 +657,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show versions of that particular task (identifiable by {Project, Domain, Name}). """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Task Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or "*"))) _click.echo("{:50} {:40}".format("Version", "Urn")) @@ -734,8 +697,7 @@ def get_task(urn, host, insecure): The URN of the versioned task is in the form of ``tsk::::``. 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) t = client.get_task(cli_identifiers.Identifier.from_python_std(urn)) _click.echo(_tt(t)) _click.echo("") @@ -762,8 +724,7 @@ def list_workflow_names(project, domain, host, insecure, token, limit, show_all, List the names of the workflows under a scope specified by ``{project, domain}``. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Workflow Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -805,8 +766,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, versions of that particular workflow (identifiable by ``{project, domain, name}``). 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Workflow Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or "*"))) _click.echo("{:50} {:40}".format("Version", "Urn")) @@ -846,8 +806,7 @@ def get_workflow(urn, host, insecure): ``wf::::`` """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo(client.get_workflow(cli_identifiers.Identifier.from_python_std(urn))) # TODO: Print workflow pretty _click.echo("") @@ -874,9 +833,7 @@ def list_launch_plan_names(project, domain, host, insecure, token, limit, show_a List the names of the launch plans under the scope specified by {project, domain}. 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) - + client = _get_client(host, insecure) _click.echo("Launch Plan Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: wf_ids, next_token = client.list_launch_plan_ids_paginated( @@ -919,9 +876,7 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show _click.echo("Active Launch Plan Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:30} {:50} {:80}".format("Schedule", "Version", "Urn")) - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) - + client = _get_client(host, insecure) while True: active_lps, next_token = client.list_active_launch_plans_paginated( project, @@ -989,9 +944,7 @@ def list_launch_plan_versions( _click.echo("Launch Plan Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) _click.echo("{:50} {:80} {:30} {:15}".format("Version", "Urn", "Schedule", "Schedule State")) - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) - + client = _get_client(host, insecure) while True: lp_list, next_token = client.list_launch_plans_paginated( _common_models.NamedEntityIdentifier(project, domain, name), @@ -1043,8 +996,7 @@ def get_launch_plan(urn, host, insecure): The URN of a launch plan is in the form of ``lp::::`` """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo(_tt(client.get_launch_plan(cli_identifiers.Identifier.from_python_std(urn)))) # TODO: Print 
launch plan pretty _click.echo("") @@ -1061,8 +1013,7 @@ def get_active_launch_plan(project, domain, name, host, insecure): List the versions of all the launch plans under the scope specified by {project, domain}. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) lp = client.get_active_launch_plan(_common_models.NamedEntityIdentifier(project, domain, name)) _click.echo("Active Launch Plan for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) @@ -1077,8 +1028,7 @@ def get_active_launch_plan(project, domain, name, host, insecure): @_optional_urn_option def update_launch_plan(state, host, insecure, urn=None): _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) if urn is None: try: @@ -1130,8 +1080,7 @@ def recover_execution(urn, name, host, insecure): Users should use the get-execution and get-launch-plan commands to ascertain the names of inputs to use. 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Recovering execution {}\n".format(_tt(urn))) @@ -1168,8 +1117,7 @@ def terminate_execution(host, insecure, cause, urn=None): -u lp:flyteexamples:development:some-execution:abc123 """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Killing the following executions:\n") _click.echo("{:100} {:40}".format("Urn", "Cause")) @@ -1221,8 +1169,7 @@ def list_executions(project, domain, host, insecure, token, limit, show_all, fil _click.echo("Executions Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:100} {:40} {:10}".format("Urn", "Name", "Status")) - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) while True: exec_ids, next_token = client.list_executions_paginated( @@ -1476,8 +1423,7 @@ def get_execution(urn, host, insecure, show_io, verbose): The URN of an execution is in the form of ``ex:::`` """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) e = client.get_execution(cli_identifiers.WorkflowExecutionIdentifier.from_python_std(urn)) node_execs = _get_all_node_executions(client, workflow_execution_identifier=e.id) _render_node_executions(client, node_execs, show_io, verbose, host, insecure, wf_execution=e) @@ -1491,8 +1437,7 @@ def get_execution(urn, host, insecure, show_io, 
verbose): @_verbose_option def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) node_execs = _get_all_node_executions( client, task_execution_identifier=cli_identifiers.TaskExecutionIdentifier.from_python_std(urn), @@ -1512,8 +1457,7 @@ def register_project(identifier, name, description, host, insecure): """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) client.register_project(_Project(identifier, name, description)) _click.echo("Registered project [id: {}, name: {}, description: {}]".format(identifier, name, description)) @@ -1532,8 +1476,7 @@ def list_projects(host, insecure, token, limit, show_all, filter, sort_by): """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _click.echo("Projects Found\n") while True: @@ -1566,8 +1509,7 @@ def archive_project(identifier, host, insecure): """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) client.update_project(_Project.archived_project(identifier)) _click.echo("Archived project [id: {}]".format(identifier)) @@ -1583,8 +1525,7 @@ def activate_project(identifier, host, insecure): """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, 
root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) client.update_project(_Project.active_project(identifier)) _click.echo("Activated project [id: {}]".format(identifier)) @@ -1662,17 +1603,10 @@ def patch_launch_plan(entity: _GeneratedProtocolMessageType) -> _GeneratedProtoc the flyte config and/or a custom output_location_prefix. """ # entity is of type flyteidl.admin.launch_plan_pb2.LaunchPlanSpec - auth_assumable_iam_role = ( - assumable_iam_role if assumable_iam_role is not None else _auth_config.ASSUMABLE_IAM_ROLE.get() - ) - auth_k8s_service_account = ( - kubernetes_service_account - if kubernetes_service_account is not None - else _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() - ) entity.spec.auth_role.CopyFrom( _AuthRole( - assumable_iam_role=auth_assumable_iam_role, kubernetes_service_account=auth_k8s_service_account + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ).to_flyte_idl(), ) @@ -1680,10 +1614,11 @@ def patch_launch_plan(entity: _GeneratedProtocolMessageType) -> _GeneratedProtoc entity.spec.raw_output_data_config.CopyFrom( _RawOutputDataConfig(output_location_prefix=output_location_prefix).to_flyte_idl() ) - elif _auth_config.RAW_OUTPUT_DATA_PREFIX.get() is not None: - entity.spec.raw_output_data_config.CopyFrom( - _RawOutputDataConfig(output_location_prefix=_auth_config.RAW_OUTPUT_DATA_PREFIX.get()).to_flyte_idl() - ) + + _click.echo( + f"IAM_Role: {assumable_iam_role}, ServiceAccount: {kubernetes_service_account}," + f" OutputLocationPrefix: {output_location_prefix}" + ) return entity @@ -1769,8 +1704,7 @@ def register_files( assumable_iam_role, kubernetes_service_account, output_location_prefix ) } - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _extract_and_register(client, project, domain, version, files, 
patches) @@ -1895,8 +1829,7 @@ def fast_register_task(entity: _GeneratedProtocolMessageType) -> _GeneratedProto assumable_iam_role, kubernetes_service_account, output_location_prefix ), } - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) _extract_and_register(client, project, domain, version, pb_files, patches) @@ -1914,8 +1847,7 @@ def update_workflow_meta(description, state, host, insecure, project, domain, na Updates a workflow entity under the scope specified by {project, domain, name} across versions. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) if state == "active": state = _named_entity.NamedEntityState.ACTIVE elif state == "archived": @@ -1940,8 +1872,7 @@ def update_task_meta(description, host, insecure, project, domain, name): Updates a task entity under the scope specified by {project, domain, name} across versions. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) client.update_named_entity( _core_identifier.ResourceType.TASK, _named_entity.NamedEntityIdentifier(project, domain, name), @@ -1962,8 +1893,7 @@ def update_launch_plan_meta(description, host, insecure, project, domain, name): Updates a launch plan entity under the scope specified by {project, domain, name} across versions. 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) client.update_named_entity( _core_identifier.ResourceType.LAUNCH_PLAN, _named_entity.NamedEntityIdentifier(project, domain, name), @@ -1992,8 +1922,7 @@ def update_cluster_resource_attributes(host, insecure, project, domain, name, at --attributes projectQuotaCpu 1 --attributes projectQuotaMemory 500M """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) cluster_resource_attributes = _ClusterResourceAttributes({attribute[0]: attribute[1] for attribute in attributes}) matching_attributes = _MatchingAttributes(cluster_resource_attributes=cluster_resource_attributes) @@ -2027,8 +1956,7 @@ def update_execution_queue_attributes(host, insecure, project, domain, name, tag --tags critical --tags gpu_intensive """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) execution_queue_attributes = _ExecutionQueueAttributes(list(tags)) matching_attributes = _MatchingAttributes(execution_queue_attributes=execution_queue_attributes) @@ -2062,8 +1990,7 @@ def update_execution_cluster_label(host, insecure, project, domain, name, value) $ flyte-cli -h localhost:30081 -p flyteexamples -d development update-execution-cluster-label --value foo """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) execution_cluster_label = 
_ExecutionClusterLabel(value) matching_attributes = _MatchingAttributes(execution_cluster_label=execution_cluster_label) @@ -2101,8 +2028,7 @@ def update_plugin_override(host, insecure, project, domain, name, task_type, plu --plugin-id my_cool_plugin --plugin-id my_fallback_plugin --missing-plugin-behavior FAIL """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) plugin_override = _PluginOverride( task_type, list(plugin_id), _PluginOverride.string_to_enum(missing_plugin_behavior.upper()) ) @@ -2146,8 +2072,7 @@ def get_matching_attributes(host, insecure, project, domain, name, resource_type combination. """ _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) if name is not None: attributes = client.get_workflow_attributes( @@ -2183,8 +2108,7 @@ def list_matching_attributes(host, insecure, resource_type): Fetches all matchable resources of the given resource type. 
""" _welcome_message() - parent_ctx = _click.get_current_context(silent=True) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, root_cert_file=parent_ctx.obj["cacert"]) + client = _get_client(host, insecure) attributes = client.list_matchable_attributes(_MatchableResource.string_to_enum(resource_type.upper())) for configuration in attributes.configurations: @@ -2250,7 +2174,6 @@ def setup_config(host, insecure): # ConfigParser needs all keys to be strings parser.set("credentials", key, str(credentials_config[key])) parser.write(f) - set_flyte_config_file(config_file_path=config_file) _click.secho("Wrote default config file to {}".format(_tt(config_file)), fg="blue") diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index f87a7d2a11..73274e972e 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -6,7 +6,7 @@ from flyteidl.core import identifier_pb2 as _identifier_pb2 from flyteidl.core import workflow_pb2 as _workflow_pb2 -from flytekit.clis.sdk_in_container.serialize import _DOMAIN_PLACEHOLDER, _PROJECT_PLACEHOLDER, _VERSION_PLACEHOLDER +from flytekit.configuration import DOMAIN_PLACEHOLDER, PROJECT_PLACEHOLDER, VERSION_PLACEHOLDER def parse_args_into_dict(input_arguments): @@ -33,13 +33,13 @@ def str2bool(str): def _hydrate_identifier( project: str, domain: str, version: str, identifier: _identifier_pb2.Identifier ) -> _identifier_pb2.Identifier: - if not identifier.project or identifier.project == _PROJECT_PLACEHOLDER: + if not identifier.project or identifier.project == PROJECT_PLACEHOLDER: identifier.project = project - if not identifier.domain or identifier.domain == _DOMAIN_PLACEHOLDER: + if not identifier.domain or identifier.domain == DOMAIN_PLACEHOLDER: identifier.domain = domain - if not identifier.version or identifier.version == _VERSION_PLACEHOLDER: + if not identifier.version or identifier.version == VERSION_PLACEHOLDER: identifier.version = version return identifier diff --git 
a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index bc2ff68666..d0d7f7a229 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -6,6 +6,7 @@ CTX_TEST = "test" CTX_PACKAGES = "pkgs" CTX_NOTIFICATIONS = "notifications" +CTX_CONFIG_FILE = "config_file" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index 6532fff9e5..b10e851449 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -2,46 +2,18 @@ import sys import tarfile import tempfile -import typing import click -from flytekit.clis.sdk_in_container import constants, serialize -from flytekit.configuration import internal +from flytekit.clis.sdk_in_container import constants +from flytekit.configuration import ( + DEFAULT_RUNTIME_PYTHON_INTERPRETER, + FastSerializationSettings, + ImageConfig, + SerializationSettings, +) from flytekit.core import context_manager -from flytekit.core.context_manager import ImageConfig, look_up_image_info -from flytekit.tools import fast_registration, module_loader - -_DEFAULT_IMAGE_NAME = "default" -_DEFAULT_RUNTIME_PYTHON_INTERPRETER = "/opt/venv/bin/python3" - - -def validate_image(ctx: typing.Any, param: str, values: tuple) -> ImageConfig: - """ - Validates the image to match the standard format. Also validates that only one default image - is provided. a default image, is one that is specified as - default=img or just img. 
All other images should be provided with a name, in the format - name=img - """ - default_image = None - images = [] - for v in values: - if "=" in v: - splits = v.split("=", maxsplit=1) - img = look_up_image_info(name=splits[0], tag=splits[1], optional_tag=False) - else: - img = look_up_image_info(_DEFAULT_IMAGE_NAME, v, False) - - if default_image and img.name == _DEFAULT_IMAGE_NAME: - raise click.BadParameter( - f"Only one default image can be specified. Received multiple {default_image} & {img} for {param}" - ) - if img.name == _DEFAULT_IMAGE_NAME: - default_image = img - else: - images.append(img) - - return ImageConfig(default_image, images) +from flytekit.tools import fast_registration, module_loader, serialize_helpers @click.command("package") @@ -52,7 +24,7 @@ def validate_image(ctx: typing.Any, param: str, values: tuple) -> ImageConfig: required=False, multiple=True, type=click.UNPROCESSED, - callback=validate_image, + callback=ImageConfig.validate_image, help="A fully qualified tag for an docker image, e.g. somedocker.com/myimage:someversion123. This is a " "multi-option and can be of the form --image xyz.io/docker:latest" " --image my_image=xyz.io/docker2:latest. Note, the `name=image_uri`. The name is optional, if not" @@ -95,7 +67,7 @@ def validate_image(ctx: typing.Any, param: str, values: tuple) -> ImageConfig: @click.option( "-p", "--python-interpreter", - default=_DEFAULT_RUNTIME_PYTHON_INTERPRETER, + default=DEFAULT_RUNTIME_PYTHON_INTERPRETER, required=False, help="Use this to override the default location of the in-container python interpreter that will be used by " "Flyte to load your program. 
This is usually where you install flytekit within the container.", @@ -120,23 +92,13 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ if os.path.exists(output) and not force: raise click.BadParameter(click.style(f"Output file {output} already exists, specify -f to override.", fg="red")) - env_binary_path = os.path.dirname(python_interpreter) - venv_root = os.path.dirname(env_binary_path) - serialization_settings = context_manager.SerializationSettings( - project=serialize._PROJECT_PLACEHOLDER, - domain=serialize._DOMAIN_PLACEHOLDER, - version=serialize._VERSION_PLACEHOLDER, - fast_serialization_settings=context_manager.FastSerializationSettings( + serialization_settings = SerializationSettings( + image_config=image_config, + fast_serialization_settings=FastSerializationSettings( enabled=fast, destination_dir=in_container_source_path, ), - image_config=image_config, - env={internal.IMAGE.env_var: image_config.default_image.full}, # TODO this env variable should be deprecated - flytekit_virtualenv_root=venv_root, python_interpreter=python_interpreter, - entrypoint_settings=context_manager.EntrypointSettings( - path=os.path.join(venv_root, serialize._DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC) - ), ) pkgs = ctx.obj[constants.CTX_PACKAGES] @@ -151,11 +113,11 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ click.secho(f"Loading packages {pkgs} under source root {source}", fg="yellow") module_loader.just_load_modules(pkgs=pkgs) - registrable_entities = serialize.get_registrable_entities(ctx) + registrable_entities = serialize_helpers.get_registrable_entities(ctx) if registrable_entities: with tempfile.TemporaryDirectory() as output_tmpdir: - serialize.persist_registrable_entities(registrable_entities, output_tmpdir) + serialize_helpers.persist_registrable_entities(registrable_entities, output_tmpdir) # If Fast serialization is enabled, then an archive is also created and packaged if fast: diff --git 
a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index f370c82942..20893337c9 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,18 +1,12 @@ -import os as _os -from pathlib import Path - import click -from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES +from flytekit import configuration +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES from flytekit.clis.sdk_in_container.init import init from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package from flytekit.clis.sdk_in_container.serialize import serialize -from flytekit.configuration import platform as _platform_config -from flytekit.configuration import set_flyte_config_file -from flytekit.configuration.internal import CONFIGURATION_PATH -from flytekit.configuration.platform import URL as _URL -from flytekit.configuration.sdk import WORKFLOW_PACKAGES as _WORKFLOW_PACKAGES +from flytekit.configuration.internal import LocalSDK def validate_package(ctx, param, values): @@ -25,13 +19,6 @@ def validate_package(ctx, param, values): @click.group("pyflyte", invoke_without_command=True) -@click.option( - "-c", - "--config", - required=False, - type=str, - help="Path to config file for use within container", -) @click.option( "-k", "--pkgs", @@ -42,60 +29,31 @@ def validate_package(ctx, param, values): "option will override the option specified in the configuration file, or environment variable", ) @click.option( - "-i", - "--insecure", + "-c", + "--config", required=False, - type=bool, - help="Disable SSL when connecting to Flyte backend.", + type=str, + help="Path to config file for use within container", ) @click.pass_context -def main(ctx, config=None, pkgs=None, insecure=None): +def main(ctx, pkgs=None, config=None): """ Entrypoint for all the user commands. 
""" - update_configuration_file(config) - ctx.obj = dict() - # Determine SSL. Note that the insecure option in this command is not a flag because we want to look - # up configuration settings if it's missing. If the command line option says insecure but the config object - # says no, let's override the config object by overriding the environment variable. - if insecure and not _platform_config.INSECURE.get(): - _platform_config.INSECURE.get() - _os.environ[_platform_config.INSECURE.env_var] = "True" - # Handle package management - get from config if not specified on the command line pkgs = pkgs or [] - if len(pkgs) == 0: - pkgs = _WORKFLOW_PACKAGES.get() + if config: + ctx.obj[CTX_CONFIG_FILE] = config + cfg = configuration.ConfigFile(config) + if not pkgs: + pkgs = LocalSDK.WORKFLOW_PACKAGES.read(cfg) + if pkgs is None: + pkgs = [] ctx.obj[CTX_PACKAGES] = pkgs -def update_configuration_file(config_file_path): - """ - Changes the configuration singleton object to read from another file if specified, which should be - at the base of the repository. - - :param Text config_file_path: - """ - configuration_file = Path(config_file_path or CONFIGURATION_PATH.get()) - if configuration_file.is_file(): - click.secho( - "Using configuration file at {}".format(configuration_file.absolute().as_posix()), - fg="green", - ) - set_flyte_config_file(configuration_file.as_posix()) - else: - click.secho( - "Configuration file '{}' could not be loaded. 
Using values from environment.".format( - CONFIGURATION_PATH.get() - ), - color="yellow", - ) - set_flyte_config_file(None) - click.secho("Flyte Admin URL {}".format(_URL.get()), fg="green") - - main.add_command(serialize) main.add_command(package) main.add_command(local_cache) diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index 09c06f3d0f..99b25b7e1e 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -1,51 +1,23 @@ -import math as _math -import os as _os +import os import sys import tarfile as _tarfile import typing -from collections import OrderedDict from enum import Enum as _Enum import click -from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan -from flyteidl.admin.task_pb2 import TaskSpec as _idl_admin_TaskSpec -from flyteidl.admin.workflow_pb2 import WorkflowSpec as _idl_admin_WorkflowSpec -import flytekit as _flytekit +from flytekit.clis.sdk_in_container import constants from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES -from flytekit.configuration import internal as _internal_config +from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.core import context_manager as flyte_context -from flytekit.core.base_task import PythonTask -from flytekit.core.launch_plan import LaunchPlan -from flytekit.core.workflow import WorkflowBase from flytekit.exceptions.scopes import system_entry_point -from flytekit.exceptions.user import FlyteValidationException -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import task as task_models -from flytekit.models.admin import workflow as admin_workflow_models -from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import compute_digest as _compute_digest from flytekit.tools.fast_registration import filter_tar_file_fn as _filter_tar_file_fn 
from flytekit.tools.module_loader import trigger_loading -from flytekit.tools.translator import get_serializable - -# Identifier fields use placeholders for registration-time substitution. -# Additional fields, such as auth and the raw output data prefix have more complex structures -# and can be optional so they are not serialized with placeholders. -_PROJECT_PLACEHOLDER = "{{ registration.project }}" -_DOMAIN_PLACEHOLDER = "{{ registration.domain }}" -_VERSION_PLACEHOLDER = "{{ registration.version }}" - - -# During out of container serialize the absolute path of the flytekit virtualenv at serialization time won't match the -# in-container value at execution time. The following default value is used to provide the in-container virtualenv path -# but can be optionally overridden at serialization time based on the installation of your flytekit virtualenv. -_DEFAULT_FLYTEKIT_VIRTUALENV_ROOT = "/opt/venv/" -_DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC = "bin/entrypoint.py" +from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities CTX_IMAGE = "image" CTX_LOCAL_SRC_ROOT = "local_source_root" -CTX_CONFIG_FILE_LOC = "config_file_loc" CTX_FLYTEKIT_VIRTUALENV_ROOT = "flytekit_virtualenv_root" CTX_PYTHON_INTERPRETER = "python_interpreter" @@ -55,102 +27,16 @@ class SerializationMode(_Enum): FAST = 1 -def _should_register_with_admin(entity) -> bool: - """ - This is used in the code below. The translator.py module produces lots of objects (namely nodes and BranchNodes) - that do not/should not be written to .pb file to send to admin. This function filters them out. - """ - return isinstance( - entity, (task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec) - ) - - -def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: - """ - Given a list of `TaskSpec`, this function returns a set containing the duplicated `TaskSpec` if any exists. 
- """ - seen: typing.Set[_identifier.Identifier] = set() - duplicate_tasks: typing.Set[task_models.TaskSpec] = set() - for task in tasks: - if task.template.id not in seen: - seen.add(task.template.id) - else: - duplicate_tasks.add(task) - return duplicate_tasks - - -def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List: - """ - Returns all entities that can be serialized and should be sent over to Flyte backend. This will filter any entities - that are not known to Admin - """ - new_api_serializable_entities = OrderedDict() - # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan - # object, which gets added to the FlyteEntities.entities list, which we're iterating over. - for entity in flyte_context.FlyteEntities.entities.copy(): - if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan): - get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity) - - if isinstance(entity, WorkflowBase): - lp = LaunchPlan.get_default_launch_plan(ctx, entity) - get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp) - - new_api_model_values = list(new_api_serializable_entities.values()) - entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) - serializable_tasks: typing.List[task_models.TaskSpec] = [ - entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) - ] - # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same - # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate - # tasks are considered invalid at registration - # time and usually indicate user error, so we catch this common mistake at serialization time. 
- duplicate_tasks = _find_duplicate_tasks(serializable_tasks) - if len(duplicate_tasks) > 0: - duplicate_task_names = [task.template.id.name for task in duplicate_tasks] - raise FlyteValidationException( - f"Multiple definitions of the following tasks were found: {duplicate_task_names}" - ) - - return [v.to_flyte_idl() for v in entities_to_be_serialized] - - -def persist_registrable_entities(entities: typing.List, folder: str): - """ - For protobuf serializable list of entities, writes a file with the name if the entity and - enumeration order to the specified folder - """ - zero_padded_length = _determine_text_chars(len(entities)) - for i, entity in enumerate(entities): - name = "" - fname_index = str(i).zfill(zero_padded_length) - if isinstance(entity, _idl_admin_TaskSpec): - name = entity.template.id.name - fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name) - elif isinstance(entity, _idl_admin_WorkflowSpec): - name = entity.template.id.name - fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name) - elif isinstance(entity, _idl_admin_LaunchPlan): - name = entity.id.name - fname = "{}_{}_3.pb".format(fname_index, entity.id.name) - else: - click.secho(f"Entity is incorrect formatted {entity} - type {type(entity)}", fg="red") - sys.exit(-1) - click.secho(f" Packaging {name} -> {fname}", dim=True) - fname = _os.path.join(folder, fname) - with open(fname, "wb") as writer: - writer.write(entity.SerializeToString()) - - @system_entry_point def serialize_all( pkgs: typing.List[str] = None, - local_source_root: str = None, - folder: str = None, - mode: SerializationMode = None, - image: str = None, - config_path: str = None, - flytekit_virtualenv_root: str = None, - python_interpreter: str = None, + local_source_root: typing.Optional[str] = None, + folder: typing.Optional[str] = None, + mode: typing.Optional[SerializationMode] = None, + image: typing.Optional[str] = None, + flytekit_virtualenv_root: typing.Optional[str] = None, + 
python_interpreter: typing.Optional[str] = None, + config_file: typing.Optional[str] = None, ): """ This function will write to the folder specified the following protobuf types :: @@ -168,36 +54,22 @@ def serialize_all( :param folder: Where to write the output protobuf files :param mode: Regular vs fast :param image: The fully qualified and versioned default image to use - :param config_path: Path to the config file, if any, to be used during serialization :param flytekit_virtualenv_root: The full path of the virtual env in the container. """ - env = { - _internal_config.CONFIGURATION_PATH.env_var: config_path - if config_path - else _internal_config.CONFIGURATION_PATH.get(), - _internal_config.IMAGE.env_var: image, - } - if not (mode == SerializationMode.DEFAULT or mode == SerializationMode.FAST): raise AssertionError(f"Unrecognized serialization mode: {mode}") - fast_serialization_settings = flyte_context.FastSerializationSettings( - enabled=mode == SerializationMode.FAST, - # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here - ) - serialization_settings = flyte_context.SerializationSettings( - project=_PROJECT_PLACEHOLDER, - domain=_DOMAIN_PLACEHOLDER, - version=_VERSION_PLACEHOLDER, - image_config=flyte_context.get_image_config(img_name=image), - env=env, + + serialization_settings = SerializationSettings( + image_config=ImageConfig.auto(config_file, img_name=image), + fast_serialization_settings=FastSerializationSettings( + enabled=mode == SerializationMode.FAST, + # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here + ), flytekit_virtualenv_root=flytekit_virtualenv_root, python_interpreter=python_interpreter, - entrypoint_settings=flyte_context.EntrypointSettings( - path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC) - ), - fast_serialization_settings=fast_serialization_settings, ) + ctx = 
flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings) with flyte_context.FlyteContextManager.with_context(ctx) as ctx: trigger_loading(pkgs, local_source_root=local_source_root) @@ -210,24 +82,17 @@ def serialize_all( click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green") -def _determine_text_chars(length): - """ - This function is used to help prefix files. If there are only 10 entries, then we just need one digit (0-9) to be - the prefix. If there are 11, then we'll need two (00-10). - - :param int length: - :rtype: int - """ - if length == 0: - return 0 - return _math.ceil(_math.log(length, 10)) - - @click.group("serialize") -@click.option("--image", help="Text tag: e.g. somedocker.com/myimage:someversion123", required=False) +@click.option( + "--image", + required=False, + default=lambda: os.environ.get("FLYTE_INTERNAL_IMAGE", ""), + help="Text tag: e.g. somedocker.com/myimage:someversion123", +) @click.option( "--local-source-root", required=False, + default=lambda: os.getcwd(), help="Root dir for python code containing workflow definitions to operate on when not the current working directory" "Optional when running `pyflyte serialize` in out of container mode and your code lies outside of your working directory", ) @@ -241,7 +106,7 @@ def _determine_text_chars(length): @click.option( "--in-container-virtualenv-root", required=False, - help="This is the root of the flytekit virtual env in your container. " + help="DEPRECATED: This flag is ignored! This is the root of the flytekit virtual env in your container. " "The reason it needs to be a separate option is because this pyflyte utility cannot know where flytekit is " "installed inside your container. 
Required for running `pyflyte serialize` in out of container mode when " "your container installs the flytekit virtualenv outside of the default `/opt/venv`", @@ -255,35 +120,20 @@ def serialize(ctx, image, local_source_root, in_container_config_path, in_contai object contains the WorkflowTemplate, along with the relevant tasks for that workflow. In lieu of Admin, this serialization step will set the URN of the tasks to the fully qualified name of the task function. """ - if not image: - image = _internal_config.IMAGE.get() ctx.obj[CTX_IMAGE] = image - - if local_source_root is None: - local_source_root = _os.getcwd() ctx.obj[CTX_LOCAL_SRC_ROOT] = local_source_root click.echo("Serializing Flyte elements with image {}".format(image)) - ctx.obj[CTX_CONFIG_FILE_LOC] = in_container_config_path - if in_container_config_path is not None: - # We're in the process of an out of container serialize call. - # Set the entrypoint path to the in container default unless a user-specified option exists. - ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT] = ( - in_container_virtualenv_root - if in_container_virtualenv_root is not None - else _DEFAULT_FLYTEKIT_VIRTUALENV_ROOT - ) - - # append python3 - ctx.obj[CTX_PYTHON_INTERPRETER] = ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT] + "/bin/python3" + if in_container_virtualenv_root: + ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT] = in_container_virtualenv_root + ctx.obj[CTX_PYTHON_INTERPRETER] = os.path.join(in_container_virtualenv_root, "/bin/python3") else: # For in container serialize we make sure to never accept an override the entrypoint path and determine it here # instead. 
- entrypoint_path = _os.path.abspath(_flytekit.__file__) - if entrypoint_path.endswith(".pyc"): - entrypoint_path = entrypoint_path[:-1] + import flytekit - ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT] = _os.path.dirname(entrypoint_path) + flytekit_install_loc = os.path.abspath(flytekit.__file__) + ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT] = os.path.dirname(flytekit_install_loc) ctx.obj[CTX_PYTHON_INTERPRETER] = sys.executable @@ -293,6 +143,7 @@ def serialize(ctx, image, local_source_root, in_container_config_path, in_contai @click.option("-f", "--folder", type=click.Path(exists=True)) @click.pass_context def workflows(ctx, folder=None): + if folder: click.echo(f"Writing output to {folder}") @@ -304,8 +155,9 @@ def workflows(ctx, folder=None): folder, SerializationMode.DEFAULT, image=ctx.obj[CTX_IMAGE], - config_path=ctx.obj[CTX_CONFIG_FILE_LOC], flytekit_virtualenv_root=ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT], + python_interpreter=ctx.obj[CTX_PYTHON_INTERPRETER], + config_file=ctx.obj.get(constants.CTX_CONFIG_FILE, None), ) @@ -319,13 +171,14 @@ def fast(ctx): @click.option("-f", "--folder", type=click.Path(exists=True)) @click.pass_context def fast_workflows(ctx, folder=None): + if folder: click.echo(f"Writing output to {folder}") source_dir = ctx.obj[CTX_LOCAL_SRC_ROOT] digest = _compute_digest(source_dir) folder = folder if folder else "" - archive_fname = _os.path.join(folder, f"{digest}.tar.gz") + archive_fname = os.path.join(folder, f"{digest}.tar.gz") click.echo(f"Writing compressed archive to {archive_fname}") # Write using gzip with _tarfile.open(archive_fname, "w:gz") as tar: @@ -339,8 +192,9 @@ def fast_workflows(ctx, folder=None): folder, SerializationMode.FAST, image=ctx.obj[CTX_IMAGE], - config_path=ctx.obj[CTX_CONFIG_FILE_LOC], flytekit_virtualenv_root=ctx.obj[CTX_FLYTEKIT_VIRTUALENV_ROOT], + python_interpreter=ctx.obj[CTX_PYTHON_INTERPRETER], + config_file=ctx.obj.get(constants.CTX_CONFIG_FILE, None), ) diff --git a/flytekit/configuration/__init__.py 
b/flytekit/configuration/__init__.py index 3f36992aad..af4e0bc5f2 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -1,57 +1,612 @@ -import os as _os -import pathlib as _pathlib +from __future__ import annotations -from flytekit.loggers import logger +import base64 +import datetime +import enum +import gzip +import os +import re +import sys +import tempfile +import typing +from dataclasses import dataclass, field +from typing import Dict, List, Optional +from dataclasses_json import dataclass_json +from docker_image import reference -def set_flyte_config_file(config_file_path): +from flytekit.configuration import internal as _internal +from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, set_if_exists + +PROJECT_PLACEHOLDER = "{{ registration.project }}" +DOMAIN_PLACEHOLDER = "{{ registration.domain }}" +VERSION_PLACEHOLDER = "{{ registration.version }}" +DEFAULT_RUNTIME_PYTHON_INTERPRETER = "/opt/venv/bin/python3" +DEFAULT_FLYTEKIT_ENTRYPOINT_FILELOC = "bin/entrypoint.py" +DEFAULT_IMAGE_NAME = "default" +DEFAULT_IN_CONTAINER_SRC_PATH = "/root" +_IMAGE_FQN_TAG_REGEX = re.compile(r"([^:]+)(?=:.+)?") + + +@dataclass_json +@dataclass(init=True, repr=True, eq=True, frozen=True) +class Image(object): """ - :param Text config_file_path: + Image is a structured wrapper for task container images used in object serialization. + + Attributes: + name (str): A user-provided name to identify this image. + fqn (str): Fully qualified image name. This consists of + #. a registry location + #. a username + #. 
a repository name + For example: `hostname/username/reponame` + tag (str): Optional tag used to specify which version of an image to pull """ - import flytekit.configuration.common as _common - import flytekit.configuration.internal as _internal - if config_file_path is not None: - original_config_file_path = config_file_path - config_file_path = _os.path.abspath(config_file_path) - if not _pathlib.Path(config_file_path).is_file(): - logger.warning( - f"No config file provided or invalid flyte config_file_path {original_config_file_path} specified." - ) - _os.environ[_internal.CONFIGURATION_PATH.env_var] = config_file_path - elif _internal.CONFIGURATION_PATH.env_var in _os.environ: - logger.debug("Deleting configuration path {} from env".format(_internal.CONFIGURATION_PATH.env_var)) - del _os.environ[_internal.CONFIGURATION_PATH.env_var] - _common.CONFIGURATION_SINGLETON.reset_config(config_file_path) + name: str + fqn: str + tag: str + @property + def full(self) -> str: + """ " + Return the full image name with tag. + """ + return f"{self.fqn}:{self.tag}" -class TemporaryConfiguration(object): - def __init__(self, new_config_path, internal_overrides=None): + @staticmethod + def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image: """ - :param Text new_config_path: + Looks up the image tag from environment variable (should be set from the Dockerfile). + FLYTE_INTERNAL_IMAGE should be the environment variable. + + This function is used when registering tasks/workflows with Admin. + When using the canonical Python-based development cycle, the version that is used to register workflows + and tasks with Admin should be the version of the image itself, which should ideally be something unique + like the sha of the latest commit. + + :param optional_tag: + :param name: + :param Text tag: e.g. 
somedocker.com/myimage:someversion123 + :rtype: Text """ - import flytekit.configuration.common as _common + ref = reference.Reference.parse(tag) + if not optional_tag and ref["tag"] is None: + raise AssertionError(f"Incorrectly formatted image {tag}, missing tag value") + else: + return Image(name=name, fqn=ref["name"], tag=ref["tag"]) - self._internal_overrides = { - _common.format_section_key("internal", k): v for k, v in (internal_overrides or {}).items() - } - self._new_config_path = new_config_path - self._old_config_path = None - self._old_internals = None - def __enter__(self): - import flytekit.configuration.internal as _internal +@dataclass_json +@dataclass(init=True, repr=True, eq=True, frozen=True) +class ImageConfig(object): + """ + ImageConfig holds available images which can be used at registration time. A default image can be specified + along with optional additional images. Each image in the config must have a unique name. + + Attributes: + default_image (str): The default image to be used as a container for task serialization. + images (List[Image]): Optional, additional images which can be used in task container definitions. + """ + + default_image: Optional[Image] = None + images: Optional[List[Image]] = None - self._old_internals = {k: _os.environ.get(k) for k in self._internal_overrides.keys()} - self._old_config_path = _os.environ.get(_internal.CONFIGURATION_PATH.env_var) - _os.environ.update(self._internal_overrides) - set_flyte_config_file(self._new_config_path) + def find_image(self, name) -> Optional[Image]: + """ + Return an image, by name, if it exists. 
+ """ + lookup_images = [self.default_image] if self.default_image else [] + if self.images: + lookup_images.extend(self.images) + # lookup_images = l + [self.default_image] if self.images else [self.default_image] + for i in lookup_images: + if i.name == name: + return i + return None + + @staticmethod + def validate_image(_: typing.Any, param: str, values: tuple) -> ImageConfig: + """ + Validates the image to match the standard format. Also validates that only one default image + is provided. a default image, is one that is specified as + default=img or just img. All other images should be provided with a name, in the format + name=img + """ + default_image = None + images = [] + for v in values: + if "=" in v: + splits = v.split("=", maxsplit=1) + img = Image.look_up_image_info(name=splits[0], tag=splits[1], optional_tag=False) + else: + img = Image.look_up_image_info(DEFAULT_IMAGE_NAME, v, False) - def __exit__(self, exc_type, exc_val, exc_tb): - for k, v in self._old_internals.items(): - if v is not None: - _os.environ[k] = v + if default_image and img.name == DEFAULT_IMAGE_NAME: + raise ValueError( + f"Only one default image can be specified. 
Received multiple {default_image} & {img} for {param}" + ) + if img.name == DEFAULT_IMAGE_NAME: + default_image = img else: - _os.environ.pop(k, None) - self._old_internals = None - set_flyte_config_file(self._old_config_path) + images.append(img) + + return ImageConfig(default_image, images) + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None, img_name: Optional[str] = None) -> ImageConfig: + """ + Reads from config file or from img_name + :param config_file: + :param img_name: + :return: + """ + if config_file is None and img_name is None: + raise ValueError("Either an image or a config with a default image should be provided") + + default_img = Image.look_up_image_info("default", img_name) if img_name else None + all_images = [default_img] if default_img else [] + + other_images = [] + if config_file: + config_file = get_config_file(config_file) + other_images = [ + Image.look_up_image_info(k, tag=v, optional_tag=True) + for k, v in _internal.Images.get_specified_images(config_file).items() + ] + all_images.extend(other_images) + return ImageConfig(default_image=default_img, images=all_images) + + +class AuthType(enum.Enum): + STANDARD = "standard" + BASIC = "basic" + CLIENT_CREDENTIALS = "client_credentials" + EXTERNAL_PROCESS = "external_process" + + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class PlatformConfig(object): + endpoint: str = "localhost:30081" + insecure: bool = False + command: typing.Optional[typing.List[str]] = None + """ + This command is executed to return a token using an external process. + """ + client_id: typing.Optional[str] = None + """ + This is the public identifier for the app which handles authorization for a Flyte deployment. + More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. + """ + client_credentials_secret: typing.Optional[str] = None + """ + Used for service auth, which is automatically called during pyflyte. 
This will allow the Flyte engine to read the + password directly from the environment variable. Note that this is less secure! Please only use this if mounting the + secret as a file is impossible. + """ + scopes: List[str] = field(default_factory=list) + auth_mode: AuthType = AuthType.STANDARD + + def with_parameters( + self, + endpoint: str = "localhost:30081", + insecure: bool = False, + command: typing.Optional[typing.List[str]] = None, + client_id: typing.Optional[str] = None, + client_credentials_secret: typing.Optional[str] = None, + scopes: List[str] = None, + auth_mode: AuthType = AuthType.STANDARD, + ) -> PlatformConfig: + return PlatformConfig( + endpoint=endpoint, + command=command, + client_id=client_id, + client_credentials_secret=client_credentials_secret, + scopes=scopes if scopes else [], + auth_mode=auth_mode, + ) + + @classmethod + def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig: + """ + Reads from Config file, and overrides from Environment variables. 
Refer to ConfigEntry for details + :param config_file: + :return: + """ + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "insecure", _internal.Platform.INSECURE.read(config_file)) + kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) + kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) + kwargs = set_if_exists( + kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) + ) + kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) + kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) + kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) + return PlatformConfig(**kwargs) + + @classmethod + def for_endpoint(cls, endpoint: str, insecure: bool = False) -> PlatformConfig: + return PlatformConfig(endpoint=endpoint, insecure=insecure) + + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class StatsConfig(object): + host: str = "localhost" + port: int = 8125 + disabled: bool = False + disabled_tags: bool = False + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> StatsConfig: + """ + Reads from environment variable, followed by ConfigFile provided + :param config_file: + :return: + """ + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "host", _internal.StatsD.HOST.read(config_file)) + kwargs = set_if_exists(kwargs, "port", _internal.StatsD.PORT.read(config_file)) + kwargs = set_if_exists(kwargs, "disabled", _internal.StatsD.DISABLED.read(config_file)) + kwargs = set_if_exists(kwargs, "disabled_tags", _internal.StatsD.DISABLE_TAGS.read(config_file)) + return StatsConfig(**kwargs) + + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class SecretsConfig(object): + env_prefix: str = "_FSEC_" + default_dir: str 
= os.path.join(os.sep, "etc", "secrets") + file_prefix: str = "" + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> SecretsConfig: + """ + Reads from environment variable or from config file + :param config_file: + :return: + """ + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "env_prefix", _internal.Secrets.ENV_PREFIX.read(config_file)) + kwargs = set_if_exists(kwargs, "default_dir", _internal.Secrets.DEFAULT_DIR.read(config_file)) + kwargs = set_if_exists(kwargs, "file_prefix", _internal.Secrets.FILE_PREFIX.read(config_file)) + return SecretsConfig(**kwargs) + + +@dataclass +class S3Config(object): + """ + S3 specific configuration + """ + + enable_debug: bool = False + endpoint: typing.Optional[str] = None + retries: int = 3 + backoff: datetime.timedelta = datetime.timedelta(seconds=5) + access_key_id: typing.Optional[str] = None + secret_access_key: typing.Optional[str] = None + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> S3Config: + """ + Automatically configure + :param config_file: + :return: Configr + """ + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "enable_debug", _internal.AWS.ENABLE_DEBUG.read(config_file)) + kwargs = set_if_exists(kwargs, "endpoint", _internal.AWS.S3_ENDPOINT.read(config_file)) + kwargs = set_if_exists(kwargs, "retries", _internal.AWS.RETRIES.read(config_file)) + kwargs = set_if_exists(kwargs, "backoff", _internal.AWS.BACKOFF_SECONDS.read(config_file)) + kwargs = set_if_exists(kwargs, "access_key_id", _internal.AWS.S3_ACCESS_KEY_ID.read(config_file)) + kwargs = set_if_exists(kwargs, "secret_access_key", _internal.AWS.S3_SECRET_ACCESS_KEY.read(config_file)) + return S3Config(**kwargs) + + +@dataclass +class GCSConfig(object): + """ + Any GCS specific configuration. 
+ """ + + gsutil_parallelism: bool = False + + @classmethod + def auto(self, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) + return GCSConfig(**kwargs) + + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class DataConfig(object): + """ + Any data storage specific configuration. Please do not use this to store secrets, in S3 case, as it is used in + Flyte sandbox environment we store the access key id and secret. + All DataPersistence plugins are passed all DataConfig and the plugin should correctly use the right config + """ + + s3: S3Config = S3Config() + gcs: GCSConfig = GCSConfig() + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig: + config_file = get_config_file(config_file) + return DataConfig( + s3=S3Config.auto(config_file), + gcs=GCSConfig.auto(config_file), + ) + + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class Config(object): + """ + This object represents the environment for Flytekit to perform either + 1. Interactive session with Flyte backend + 2. Some parts are required for Serialization, for example Platform Config is not required + 3. Runtime of a task + Args: + entrypoint_settings: EntrypointSettings object for use with Spark tasks. If supplied, this will be + used when serializing Spark tasks, which need to know the path to the flytekit entrypoint.py file, + inside the container. 
+ """ + + platform: PlatformConfig = PlatformConfig() + secrets: SecretsConfig = SecretsConfig() + stats: StatsConfig = StatsConfig() + data_config: DataConfig = DataConfig() + local_sandbox_path: str = tempfile.mkdtemp(prefix="flyte") + + def with_params( + self, + platform: PlatformConfig = None, + secrets: SecretsConfig = None, + stats: StatsConfig = None, + data_config: DataConfig = None, + local_sandbox_path: str = None, + ) -> Config: + return Config( + platform=platform or self.platform, + secrets=secrets or self.secrets, + stats=stats or self.stats, + data_config=data_config or self.data_config, + local_sandbox_path=local_sandbox_path or self.local_sandbox_path, + ) + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> Config: + """ + Automatically constructs the Config Object. The order of precendence is as follows + 1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format. + 2. If not found in environment then values ar read from the config file + 3. If not found in the file, then the default values are used. + :param config_file: file path to read the config from, if not specified default locations are searched + :return: Config + """ + config_file = get_config_file(config_file) + kwargs = {} + set_if_exists(kwargs, "local_sandbox_path", _internal.LocalSDK.LOCAL_SANDBOX.read(cfg=config_file)) + return Config( + platform=PlatformConfig.auto(config_file), + secrets=SecretsConfig.auto(config_file), + stats=StatsConfig.auto(config_file), + data_config=DataConfig.auto(config_file), + **kwargs, + ) + + @classmethod + def for_sandbox(cls) -> Config: + """ + Constructs a new Config object specifically to connect to :std:ref:`deploy-sandbox-local`. 
+ If you are using a hosted Sandbox like environment, then you may need to use port-forward or ingress urls + :return: Config + """ + return Config( + platform=PlatformConfig(insecure=True), + data_config=DataConfig( + s3=S3Config(endpoint="localhost:30084", access_key_id="minio", secret_access_key="miniostorage") + ), + ) + + @classmethod + def for_endpoint( + cls, + endpoint: str, + insecure: bool = False, + data_config: typing.Optional[DataConfig] = None, + config_file: typing.Union[str, ConfigFile] = None, + ) -> Config: + """ + Creates an automatic config for the given endpoint and uses the config_file or environment variable for default. + Refer to `Config.auto()` to understand the default bootstrap behavior. + + data_config can be used to configure how data is downloaded or uploaded to a specific Blob storage like S3 / GCS etc. + But, for permissions to a specific backend just use Cloud providers reqcommendation. If using fsspec, then + refer to fsspec documentation + :param endpoint: -> Endpoint where Flyte admin is available + :param insecure: -> if the connection should be insecure, default is secure (SSL ON) + :param data_config: -> Data config, if using specialized connection params like minio etc + :param config_file: -> Optional config file in the flytekit config format. + :return: Config + """ + c = cls.auto(config_file) + return c.with_params(platform=PlatformConfig.for_endpoint(endpoint, insecure), data_config=data_config) + + +@dataclass_json +@dataclass +class EntrypointSettings(object): + """ + This object carries information about the path of the entrypoint command that will be invoked at runtime. + This is where `pyflyte-execute` code can be found. This is useful for cases like pyspark execution. + """ + + path: Optional[str] = None + + +@dataclass_json +@dataclass +class FastSerializationSettings(object): + """ + This object hold information about settings necessary to serialize an object so that it can be fast-registered. 
+ """ + + enabled: bool = False + # This is the location that the code should be copied into. + destination_dir: Optional[str] = None + + # This is the zip file where the new code was uploaded to. + distribution_location: Optional[str] = None + + +@dataclass_json +@dataclass() +class SerializationSettings(object): + """ + These settings are provided while serializing a workflow and task, before registration. This is required to get + runtime information at serialization time, as well as some defaults. + + TODO: ImageConfig, python_interpreter, venv_root, fast_serialization_settings.destination_dir should be combined. + + Attributes: + project (str): The project (if any) with which to register entities under. + domain (str): The domain (if any) with which to register entities under. + version (str): The version (if any) with which to register entities under. + image_config (ImageConfig): The image config used to define task container images. + env (Optional[Dict[str, str]]): Environment variables injected into task container definitions. + flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit + virtualenv at serialization time won't match the in-container value at execution time. This optional value + is used to provide the in-container virtualenv path + python_interpreter (Optional[str]): The python executable to use. This is used for spark tasks in out of + container execution. + entrypoint_settings (Optional[EntrypointSettings]): Information about the command, path and version of the + entrypoint program. + fast_serialization_settings (Optional[FastSerializationSettings]): If the code is being serialized so that it + can be fast registered (and thus omit building a Docker image) this object contains additional parameters + for serialization. 
+ """ + + image_config: ImageConfig + project: typing.Optional[str] = None + domain: typing.Optional[str] = None + version: typing.Optional[str] = None + env: Optional[Dict[str, str]] = None + python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER + flytekit_virtualenv_root: Optional[str] = None + fast_serialization_settings: Optional[FastSerializationSettings] = None + + def __post_init__(self): + # Derive the virtualenv root from the interpreter path when the caller did not supply one. + if self.flytekit_virtualenv_root is None: + self.flytekit_virtualenv_root = self.venv_root_from_interpreter(self.python_interpreter) + + @property + def entrypoint_settings(self) -> EntrypointSettings: + """ + EntrypointSettings whose path is DEFAULT_FLYTEKIT_ENTRYPOINT_FILELOC joined onto the virtualenv root + computed from ``python_interpreter``. + """ + return EntrypointSettings( + path=os.path.join( + SerializationSettings.venv_root_from_interpreter(self.python_interpreter), + DEFAULT_FLYTEKIT_ENTRYPOINT_FILELOC, + ) + ) + + @dataclass + class Builder(object): + project: str + domain: str + version: str + image_config: ImageConfig + env: Optional[Dict[str, str]] = None + flytekit_virtualenv_root: Optional[str] = None + python_interpreter: Optional[str] = None + fast_serialization_settings: Optional[FastSerializationSettings] = None + + # fss is annotated with the settings type (FastSerializationSettings), not the field of the same name. + def with_fast_serialization_settings(self, fss: FastSerializationSettings) -> SerializationSettings.Builder: + """ + Sets the fast serialization settings on this builder and returns the builder so calls can be chained. + """ + self.fast_serialization_settings = fss + return self + + def build(self) -> SerializationSettings: + """ + Creates a SerializationSettings from the values currently held by this builder. 
+ """ + return SerializationSettings( + project=self.project, + domain=self.domain, + version=self.version, + image_config=self.image_config, + env=self.env, + flytekit_virtualenv_root=self.flytekit_virtualenv_root, + python_interpreter=self.python_interpreter, + fast_serialization_settings=self.fast_serialization_settings, + ) + + @classmethod + def from_transport(cls, s: str) -> SerializationSettings: + # Reverse of the transport encoding: base64-decode, gunzip, then parse the JSON document. + compressed_val = base64.b64decode(s.encode("utf-8")) + json_str = gzip.decompress(compressed_val).decode("utf-8") + return cls.from_json(json_str) + + @classmethod + def for_image( + cls, + image: str, + version: str, + project: str = "", + domain: str = "", +
python_interpreter_path: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER, + ) -> SerializationSettings: + img = ImageConfig(default_image=Image.look_up_image_info(DEFAULT_IMAGE_NAME, tag=image)) + return SerializationSettings( + image_config=img, + project=project, + domain=domain, + version=version, + python_interpreter=python_interpreter_path, + flytekit_virtualenv_root=cls.venv_root_from_interpreter(python_interpreter_path), + ) + + def new_builder(self) -> Builder: + """ + Creates a ``SerializationSettings.Builder`` that copies the existing serialization settings parameters and + allows for customization. + """ + return SerializationSettings.Builder( + project=self.project, + domain=self.domain, + version=self.version, + image_config=self.image_config, + env=self.env, + flytekit_virtualenv_root=self.flytekit_virtualenv_root, + python_interpreter=self.python_interpreter, + fast_serialization_settings=self.fast_serialization_settings, + ) + + def should_fast_serialize(self) -> bool: + """ + Whether or not the serialization settings specify that entities should be serialized for fast registration. 
+ """ + return self.fast_serialization_settings is not None and self.fast_serialization_settings.enabled + + def prepare_for_transport(self) -> str: + json_str = self.to_json() + compressed_value = gzip.compress(json_str.encode("utf-8")) + return base64.b64encode(compressed_value).decode("utf-8") + + @staticmethod + def venv_root_from_interpreter(interpreter_path: str) -> str: + """ + Computes the path of the virtual environment root, based on the passed in python interpreter path + for example /opt/venv/bin/python3 -> /opt/venv + """ + return os.path.dirname(os.path.dirname(interpreter_path)) + + @staticmethod + def default_entrypoint_settings(interpreter_path: str) -> EntrypointSettings: + """ + Assumes the entrypoint is installed in a virtual-environment where the interpreter is + """ + return EntrypointSettings( + path=os.path.join( + SerializationSettings.venv_root_from_interpreter(interpreter_path), DEFAULT_FLYTEKIT_ENTRYPOINT_FILELOC + ) + ) diff --git a/flytekit/configuration/auth.py b/flytekit/configuration/auth.py deleted file mode 100644 index 5eef062244..0000000000 --- a/flytekit/configuration/auth.py +++ /dev/null @@ -1,23 +0,0 @@ -from flytekit.configuration import common as _config_common - -ASSUMABLE_IAM_ROLE = _config_common.FlyteStringConfigurationEntry("auth", "assumable_iam_role", default=None) -""" -This is the role the SDK will use by default to execute workflows. For example, in AWS this should be an IAM role -string. -""" - -KUBERNETES_SERVICE_ACCOUNT = _config_common.FlyteStringConfigurationEntry( - "auth", "kubernetes_service_account", default=None -) -""" -This is the kubernetes service account that will be passed to workflow executions. -""" - -RAW_OUTPUT_DATA_PREFIX = _config_common.FlyteStringConfigurationEntry("auth", "raw_output_data_prefix", default="") -""" -This is not output metadata but rather where users can specify an S3 or gcs path for offloaded data like blobs -and schemas. 
- -The reason this setting is in this file is because it's inextricably tied to a workflow's role or service account, -since that is what ultimately gives the tasks the ability to write to certain buckets. -""" diff --git a/flytekit/configuration/aws.py b/flytekit/configuration/aws.py deleted file mode 100644 index a6af62bf6f..0000000000 --- a/flytekit/configuration/aws.py +++ /dev/null @@ -1,23 +0,0 @@ -from flytekit.configuration import common as _config_common - -S3_SHARD_FORMATTER = _config_common.FlyteRequiredStringConfigurationEntry("aws", "s3_shard_formatter") - -S3_SHARD_STRING_LENGTH = _config_common.FlyteIntegerConfigurationEntry("aws", "s3_shard_string_length", default=2) - -S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default=None) - -S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default=None) - -S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default=None) - -S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" - -S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" - -S3_ENDPOINT_ARG_NAME = "--endpoint-url" - -ENABLE_DEBUG = _config_common.FlyteBoolConfigurationEntry("aws", "enable_debug", default=False) - -RETRIES = _config_common.FlyteIntegerConfigurationEntry("aws", "retries", default=3) - -BACKOFF_SECONDS = _config_common.FlyteIntegerConfigurationEntry("aws", "backoff_seconds", default=5) diff --git a/flytekit/configuration/common.py b/flytekit/configuration/common.py deleted file mode 100644 index 6e0b2088a4..0000000000 --- a/flytekit/configuration/common.py +++ /dev/null @@ -1,291 +0,0 @@ -import abc as _abc -import configparser as _configparser -import os as _os - -from flytekit.exceptions import user as _user_exceptions - - -def format_section_key(section, key): - """ - :param Text section: - :param Text key: - :rtype: Text - """ - return "FLYTE_{section}_{key}".format(section=section.upper(), key=key.upper()) - - -class 
FlyteConfigurationFile(object): - def __init__(self, location=None): - """ - This singleton is initialized on module load with empty location. If pyflyte is called with - a config flag, it'll reload the singleton with the passed config path. - - :param Text location: used to load config from this location. - """ - self._location = None - self._config = None - self.reset_config(location) - - def _load_config(self): - if self._config is None and self._location: - config = _configparser.ConfigParser() - config.read(self._location) - if config.has_section("internal"): - raise _user_exceptions.FlyteAssertion( - "The config file '{}' cannot contain a section for internal " - "only configurations.".format(self._location) - ) - self._config = config - - def get_string(self, section, key, default=None): - """ - :param Text section: - :param Text key: - :param Text default: - :rtype: Text - """ - self._load_config() - if self._config is not None: - try: - return self._config.get(section, key, fallback=default) - except Exception: - pass - return default - - def get_int(self, section, key, default=None): - """ - :param Text section: - :param Text key: - :param int default: - :rtype: int - """ - self._load_config() - if self._config is not None: - try: - return self._config.getint(section, key, fallback=default) - except Exception: - pass - return default - - def get_bool(self, section, key, default=None): - """ - :param Text section: - :param Text key: - :param bool default: - :rtype: bool - """ - self._load_config() - if self._config is not None: - try: - return self._config.getboolean(section, key, fallback=default) - except Exception: - pass - return default - - def reset_config(self, location): - """ - :param Text location: - """ - self._location = location or _os.environ.get("FLYTE_INTERNAL_CONFIGURATION_PATH") - self._config = None - - @property - def config(self) -> _configparser.ConfigParser: - self._load_config() - return self._config - - -class 
_FlyteConfigurationPatcher(object): - def __init__(self, new_value, config): - """ - :param Text new_value: - :param _FlyteConfigurationEntry config: - """ - self._new_value = new_value - self._config = config - self._old_value = None - - def __enter__(self): - self._old_value = _os.environ.get(self._config.env_var, None) - if self._new_value is not None: - _os.environ[self._config.env_var] = self._new_value - elif self._old_value is not None: - del _os.environ[self._config.env_var] - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._old_value is not None: - _os.environ[self._config.env_var] = self._old_value - else: - del _os.environ[self._config.env_var] - - -def _get_file_contents(location): - """ - This reads an input file, and returns the string contents, and should be used for reading credentials. - This function will also strip newlines. - - :param Text location: The file path holding the client id or secret - :rtype: Text - """ - if _os.path.isfile(location): - with open(location, "r") as f: - return f.read().replace("\n", "") - return None - - -class _FlyteConfigurationEntry(object, metaclass=_abc.ABCMeta): - def __init__(self, section, key, default=None, validator=None, fallback=None): - self._section = section - self._key = key - self._default = default - self._validator = validator - self._fallback = fallback - - @property - def env_var(self): - """ - :rtype: Text - """ - return format_section_key(self._section, self._key) - - @_abc.abstractmethod - def _getter(self): - pass - - def retrieve_value(self): - """ - The logic in this function changes the lookup behavior for all configuration objects before hitting the - configuration file. - - For a given configuration object ('mysection', 'mysetting'), it will now look at this waterfall: - - i.) The environment variable named 'FLYTE_MYSECTION_MYSETTING' - - ii.) The value of the environment variable that is named the value of the environment variable named - 'FLYTE_MYSECTION_MYSETTING'. 
That is if os.environ['FLYTE_MYSECTION_MYSETTING'] = 'AAA' and - os.environ['AA'] = 'abc', then 'abc' will be the final value. - - iii.) The contents of the file pointed to by the environment variable named 'FLYTE_MYSECTION_MYSETTING', - assuming the value is a file. - - While it is helpful, this pattern does interrupt the manually specified fallback logic, by effective injecting - two more fallbacks behind the scenes. Just keep this in mind as you are using/creating configuration objects. - :rtype: Text - """ - val = _os.environ.get(self.env_var, None) - if val is None: - referenced_env_var = _os.environ.get("{}_FROM_ENV_VAR".format(self.env_var), None) - if referenced_env_var is not None: - val = _os.environ.get(referenced_env_var, None) - if val is None: - referenced_file = _os.environ.get("{}_FROM_FILE".format(self.env_var), None) - if referenced_file is not None: - val = _get_file_contents(referenced_file) - return val - - def get(self): - val = self._getter() - if val is None and self._fallback is not None: - val = self._fallback.get() - if self._validator: - self._validator(val) - return val - - def is_set(self): - val = self._getter() - return val is not None - - def get_patcher(self, value): - return _FlyteConfigurationPatcher(value, self) - - -class _FlyteRequiredConfigurationEntry(_FlyteConfigurationEntry): - def __init__(self, section, key, validator=None): - super(_FlyteRequiredConfigurationEntry, self).__init__(section, key, validator=self._validate_not_null) - self._extra_validator = validator - - def _validate_not_null(self, val): - if val is None: - raise _user_exceptions.FlyteAssertion( - "No configuration set for [{}] {}. 
This is a required configuration.".format(self._section, self._key) - ) - if self._extra_validator: - self._extra_validator(val) - - -class FlyteStringConfigurationEntry(_FlyteConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) - return val - - -class FlyteIntegerConfigurationEntry(_FlyteConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) - if val is not None: - return int(val) - return None - - -class FlyteBoolConfigurationEntry(_FlyteConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - - if val is None: - return CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) - else: - # Because bool('False') is True, compare to the same values that ConfigParser uses - if val.lower() in ["false", "0", "off", "no"]: - return False - return True - - -class FlyteStringListConfigurationEntry(_FlyteConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) - if val is None: - return self._default - return val.split(",") - - -class FlyteRequiredStringConfigurationEntry(_FlyteRequiredConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) - return val - - -class FlyteRequiredIntegerConfigurationEntry(_FlyteRequiredConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) - return int(val) - - -class FlyteRequiredBoolConfigurationEntry(_FlyteRequiredConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - 
val = CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) - return bool(val) - - -class FlyteRequiredStringListConfigurationEntry(_FlyteRequiredConfigurationEntry): - def _getter(self): - val = self.retrieve_value() - if val is None: - val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) - if val is None: - return self._default - return val.split(",") - - -CONFIGURATION_SINGLETON = FlyteConfigurationFile() diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py deleted file mode 100644 index 9f11ac2d2e..0000000000 --- a/flytekit/configuration/creds.py +++ /dev/null @@ -1,31 +0,0 @@ -from flytekit.configuration import common as _config_common - -COMMAND = _config_common.FlyteStringListConfigurationEntry("credentials", "command", default=None) -""" -This command is executed to return a token using an external process. -""" - -CLIENT_ID = _config_common.FlyteStringConfigurationEntry("credentials", "client_id", default=None) -""" -This is the public identifier for the app which handles authorization for a Flyte deployment. -More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. -""" - -CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry("credentials", "client_secret", default=None) -""" -Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the -password directly from the environment variable. Note that this is less secure! Please only use this if mounting the -secret as a file is impossible. -""" - -SCOPES = _config_common.FlyteStringListConfigurationEntry("credentials", "scopes", default=[]) - -AUTH_MODE = _config_common.FlyteStringConfigurationEntry("credentials", "auth_mode", default="standard") -""" -The auth mode defines the behavior used to request and refresh credentials. 
The currently supported modes include: -- 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials - access. -- 'basic' or 'client_credentials' This uses cert-based auth in which the end user enters a client id and a client - secret and public key encryption is used to facilitate authentication. -- None: No auth will be attempted. -""" diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py new file mode 100644 index 0000000000..8fd4d99c01 --- /dev/null +++ b/flytekit/configuration/file.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import configparser +import configparser as _configparser +import os +import typing +from dataclasses import dataclass +from pathlib import Path + +from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import logger + + +@dataclass +class LegacyConfigEntry(object): + """ + Creates a record for the config entry. contains + Args: + section: section the option should be found unddd + option: the option str to lookup + type_: Expected type of the value + """ + + section: str + option: str + type_: typing.Type = str + + def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]: + """ + Reads the config entry from environment variable, the structure of the env var is current + ``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future. 
+ :return: + """ + env = f"FLYTE_{self.section.upper()}_{self.option.upper()}" + v = os.environ.get(env, None) + if v is None: + return None + return transform(v) if transform else v + + def read_from_file( + self, cfg: ConfigFile, transform: typing.Optional[typing.Callable] = None + ) -> typing.Optional[typing.Any]: + """ + Reads the config entry from the given config file, applying ``transform`` to the value if one is supplied. + Returns None when no config file is given or when the lookup raises a ``configparser.Error``. + """ + if not cfg: + return None + try: + v = cfg.get(self) + return transform(v) if transform else v + except configparser.Error: + pass + return None + + +def bool_transformer(config_val: typing.Any): + """ + Converts a string config value to a bool: the empty string and "false", "0", "off", "no" (case-insensitive) + are False, any other string is True. Non-string values are returned unchanged. + """ + if type(config_val) is str: + return True if config_val and not config_val.lower() in ["false", "0", "off", "no"] else False + else: + return config_val + + +def int_transformer(config_val: typing.Any): + """ + Converts a string config value (for example one read from an environment variable) to an int. + Non-string values and strings that cannot be parsed as an int are returned unchanged. + """ + if isinstance(config_val, str): + try: + return int(config_val) + except ValueError: + logger.warning(f"Could not convert configuration value {config_val!r} to int, leaving it as is") + return config_val + + +@dataclass +class ConfigEntry(object): + """ + A top level Config entry holder, that holds multiple different representations of the config. + Currently only legacy is supported, but more will be added soon + """ + + legacy: LegacyConfigEntry + transform: typing.Optional[typing.Callable[[str], typing.Any]] = None + + # Converters used when no explicit transform is given, keyed by the entry's declared type. 
Environment + # variables are always strings, so int-typed entries need a converter as well as bool-typed ones. + legacy_default_transforms = { + bool: bool_transformer, + int: int_transformer, + } + + def __post_init__(self): + if self.legacy: + if not self.transform and self.legacy.type_ in ConfigEntry.legacy_default_transforms: + self.transform = ConfigEntry.legacy_default_transforms[self.legacy.type_] + + def read(self, cfg: typing.Optional[ConfigFile] = None) -> typing.Optional[typing.Any]: + """ + Reads the config Entry from the various sources in the following order, + First try to read from environment, if not then try to read from the given config file + :param cfg: + :return: + """ + from_env = self.legacy.read_from_env(self.transform) + if from_env is None: + return self.legacy.read_from_file(cfg, self.transform) + return from_env + + +class ConfigFile(object): + def __init__(self, location: str): + """ + Load the config from this location + """ + self._location = location + # TODO, we can choose legacy vs other config using the extension.
For .yaml, we can use the new config parser + self._legacy_config = self._read_legacy_config(location) + + def _read_legacy_config(self, location: str) -> _configparser.ConfigParser: + """ + Parses the file at ``location`` with ConfigParser and returns the parser. + :raises FlyteAssertion: if the file contains an ``internal`` section + """ + c = _configparser.ConfigParser() + # Read the path that was passed in rather than relying on self._location having been assigned first. + c.read(location) + if c.has_section("internal"): + raise _user_exceptions.FlyteAssertion( + "The config file '{}' cannot contain a section for internal " "only configurations.".format(location) + ) + return c + + def _get_from_legacy(self, c: LegacyConfigEntry) -> typing.Any: + """ + Looks up the entry in the legacy config, choosing the ConfigParser getter from the entry's declared type: + bool -> getboolean, int -> getint, list -> comma-split string, anything else -> raw string. + """ + # bool is checked before int because bool is a subclass of int. + if issubclass(c.type_, bool): + return self._legacy_config.getboolean(c.section, c.option) + + if issubclass(c.type_, int): + return self._legacy_config.getint(c.section, c.option) + + if issubclass(c.type_, list): + v = self._legacy_config.get(c.section, c.option) + return v.split(",") + + return self._legacy_config.get(c.section, c.option) + + def get(self, c: typing.Union[LegacyConfigEntry]) -> typing.Any: + """ + Returns the value for the given config entry. Only LegacyConfigEntry is supported. + :raises NotImplementedError: if ``c`` is any other kind of config entry + """ + if isinstance(c, LegacyConfigEntry): + return self._get_from_legacy(c) + # NotImplemented is a comparison sentinel and is not callable; NotImplementedError is the exception to raise. 
+ raise NotImplementedError("Support for other config types besides .ini / .config files not yet supported") + + @property + def legacy_config(self) -> _configparser.ConfigParser: + return self._legacy_config + + +def get_config_file(c: typing.Union[str, ConfigFile, None]) -> typing.Optional[ConfigFile]: + """ + Checks if the given argument is a file or a configFile and returns a loaded configFile else returns None + """ + if c is None: + # See if there's a config file in the current directory where Python is being run from + current_location_config = Path("flytekit.config") + if current_location_config.exists(): + logger.info(f"Using configuration from Python process root {current_location_config.absolute()}") + return ConfigFile(current_location_config.absolute()) + + # If not, see if there's a config in the user's home directory + home_dir_config = Path(Path.home(), ".flyte", "config") # _default_config_file_name in main.py + if home_dir_config.exists(): + logger.info(f"Using configuration from
home directory {home_dir_config.absolute()}") + return ConfigFile(home_dir_config.absolute()) + + # If not, then return None and let caller handle + return None + if isinstance(c, str): + return ConfigFile(c) + return c + + +def set_if_exists(d: dict, k: str, v: typing.Any) -> dict: + """ + Given a dict ``d`` sets the key ``k`` with value of config ``v``, if the config value ``v`` is set + and return the updated dictionary. + + .. note:: + + The input dictionary ``d`` will be mutated. + """ + if v: + d[k] = v + return d diff --git a/flytekit/configuration/gcp.py b/flytekit/configuration/gcp.py deleted file mode 100644 index 11e9452083..0000000000 --- a/flytekit/configuration/gcp.py +++ /dev/null @@ -1,4 +0,0 @@ -from flytekit.configuration import common as _config_common - -GCS_PREFIX = _config_common.FlyteRequiredStringConfigurationEntry("gcp", "gcs_prefix") -GSUTIL_PARALLELISM = _config_common.FlyteBoolConfigurationEntry("gcp", "gsutil_parallelism", default=False) diff --git a/flytekit/configuration/images.py b/flytekit/configuration/images.py deleted file mode 100644 index 092ccaaff7..0000000000 --- a/flytekit/configuration/images.py +++ /dev/null @@ -1,30 +0,0 @@ -import configparser -import typing - -from flytekit.configuration import common as _config_common -from flytekit.loggers import logger - - -def get_specified_images() -> typing.Dict[str, str]: - """ - This section should contain options, where the option name is the friendly name of the image and the corresponding - value is actual FQN of the image. Example of how the section is structured - [images] - my_image1=docker.io/flyte:tag - # Note that the tag is optional. 
If not specified it will be the default version identifier specified - my_image2=docker.io/flyte - - :returns a dictionary of name: image Version is optional - """ - images: typing.Dict[str, str] = {} - if _config_common.CONFIGURATION_SINGLETON.config is None: - return images - try: - image_names = _config_common.CONFIGURATION_SINGLETON.config.options("images") - except configparser.NoSectionError: - logger.info("No images specified, will use the default image") - image_names = None - if image_names: - for i in image_names: - images[str(i)] = _config_common.FlyteStringConfigurationEntry("images", i).get() - return images diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 043778cd4d..e2641bb23c 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -1,60 +1,150 @@ -import re - -from flytekit.configuration import common as _common_config - -IMAGE = _common_config.FlyteStringConfigurationEntry("internal", "image") -# This configuration option specifies the path to the file that holds the configuration options. Don't worry, -# there will not be cycles because the parsing of the configuration file intentionally will not read and settings -# in the [internal] section. -# The default, if you want to use it, should be a file called flytekit.config, located in wherever your python -# interpreter originates. -CONFIGURATION_PATH = _common_config.FlyteStringConfigurationEntry( - "internal", "configuration_path", default="flytekit.config" -) - -# Project, Domain and Version represent the values at registration time. 
-PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "project", default="") -DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "domain", default="") -NAME = _common_config.FlyteStringConfigurationEntry("internal", "name", default="") -VERSION = _common_config.FlyteStringConfigurationEntry("internal", "version", default="") - -# Project, Domain and Version represent the values at registration time. -TASK_PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "task_project", default="") -TASK_DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "task_domain", default="") -TASK_NAME = _common_config.FlyteStringConfigurationEntry("internal", "task_name", default="") -TASK_VERSION = _common_config.FlyteStringConfigurationEntry("internal", "task_version", default="") - -# Execution project and domain represent the values passed by execution engine at runtime. -EXECUTION_PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "execution_project", default="") -EXECUTION_DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "execution_domain", default="") -EXECUTION_WORKFLOW = _common_config.FlyteStringConfigurationEntry("internal", "execution_workflow", default="") -EXECUTION_LAUNCHPLAN = _common_config.FlyteStringConfigurationEntry("internal", "execution_launchplan", default="") -EXECUTION_NAME = _common_config.FlyteStringConfigurationEntry("internal", "execution_id", default="") - -# This is another layer of logging level, which can be set by propeller, and can override the SDK configuration if -# necessary. (See the sdk.py version of this as well.) -LOGGING_LEVEL = _common_config.FlyteIntegerConfigurationEntry("internal", "logging_level") - -_IMAGE_VERSION_REGEX = ".*:(.+)" - - -def look_up_version_from_image_tag(tag): - """ - Looks up the image tag from environment variable (should be set from the Dockerfile). - FLYTE_INTERNAL_IMAGE should be the environment variable. 
- - This function is used when registering tasks/workflows with Admin. - When using the canonical Python-based development cycle, the version that is used to register workflows - and tasks with Admin should be the version of the image itself, which should ideally be something unique - like the sha of the latest commit. - - :param Text tag: e.g. somedocker.com/myimage:someversion123 - :rtype: Text - """ - if tag is None or tag == "": - raise Exception("Bad input for image tag {}".format(tag)) - m = re.match(_IMAGE_VERSION_REGEX, tag) - if m is not None: - return m.group(1) - - raise Exception("Could not parse image version from configuration. Did you set it in the" "Dockerfile?") +import configparser +import datetime +import typing + +from flytekit.configuration.file import ConfigEntry, ConfigFile, LegacyConfigEntry + + +class Images(object): + @staticmethod + def get_specified_images(cfg: ConfigFile) -> typing.Dict[str, str]: + """ + This section should contain options, where the option name is the friendly name of the image and the corresponding + value is actual FQN of the image. Example of how the section is structured + [images] + my_image1=docker.io/flyte:tag + # Note that the tag is optional. 
If not specified it will be the default version identifier specified + my_image2=docker.io/flyte + + :returns a dictionary of name: image Version is optional + """ + images: typing.Dict[str, str] = {} + if cfg is None: + return images + try: + image_names = cfg.legacy_config.options("images") + except configparser.NoSectionError: + image_names = None + if image_names: + for i in image_names: + images[str(i)] = cfg.legacy_config.get("images", i) + return images + + +class AWS(object): + SECTION = "aws" + S3_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "endpoint")) + S3_ACCESS_KEY_ID = ConfigEntry(LegacyConfigEntry(SECTION, "access_key_id")) + S3_SECRET_ACCESS_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "secret_access_key")) + ENABLE_DEBUG = ConfigEntry(LegacyConfigEntry(SECTION, "enable_debug", bool)) + RETRIES = ConfigEntry(LegacyConfigEntry(SECTION, "retries", int)) + BACKOFF_SECONDS = ConfigEntry( + LegacyConfigEntry(SECTION, "backoff_seconds", datetime.timedelta), + transform=lambda x: datetime.timedelta(seconds=int(x)), + ) + + +class GCP(object): + SECTION = "gcp" + GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool)) + + +class Credentials(object): + SECTION = "credentials" + COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command"), list) + """ + This command is executed to return a token using an external process. + """ + + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id")) + """ + This is the public identifier for the app which handles authorization for a Flyte deployment. + More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. + """ + + CLIENT_CREDENTIALS_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret")) + """ + Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the + password directly from the environment variable. Note that this is less secure! 
Please only use this if mounting the + secret as a file is impossible. + """ + + SCOPES = ConfigEntry(LegacyConfigEntry(SECTION, "scopes", list)) + + AUTH_MODE = ConfigEntry(LegacyConfigEntry(SECTION, "auth_mode")) + """ + The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: + - 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials + access. + - 'basic' or 'client_credentials' This uses cert-based auth in which the end user enters a client id and a client + secret and public key encryption is used to facilitate authentication. + - None: No auth will be attempted. + """ + + +class Platform(object): + SECTION = "platform" + URL = ConfigEntry(LegacyConfigEntry(SECTION, "url")) + INSECURE = ConfigEntry(LegacyConfigEntry(SECTION, "insecure", bool)) + + +class LocalSDK(object): + SECTION = "sdk" + WORKFLOW_PACKAGES = ConfigEntry(LegacyConfigEntry(SECTION, "workflow_packages", list)) + """ + This is a comma-delimited list of packages that SDK tools will use to discover entities for the purpose of registration + and execution of entities. + """ + + LOCAL_SANDBOX = ConfigEntry(LegacyConfigEntry(SECTION, "local_sandbox")) + """ + This is the path where SDK will place files during local executions and testing. The SDK will not automatically + clean up data in these directories. + """ + + LOGGING_LEVEL = ConfigEntry(LegacyConfigEntry(SECTION, "logging_level", int)) + """ + This is the default logging level for the Python logging library and will be set before user code runs. + Note that this configuration is special in that it is a runtime setting, not a compile time setting. This is the only + runtime option in this file. 
+ + TODO delete the one from internal config + """ + + # Feature Gate + USE_STRUCTURED_DATASET = ConfigEntry(LegacyConfigEntry(SECTION, "use_structured_dataset", bool)) + """ + Note: This gate will be switched to True at some point in the future. Definitely by 1.0, if not v0.31.0. + """ + + +class Secrets(object): + SECTION = "secrets" + # Secrets management + ENV_PREFIX = ConfigEntry(LegacyConfigEntry(SECTION, "env_prefix")) + """ + This is the prefix that will be used to lookup for injected secrets at runtime. This can be overridden to using + FLYTE_SECRETS_ENV_PREFIX variable + """ + + DEFAULT_DIR = ConfigEntry(LegacyConfigEntry(SECTION, "default_dir")) + """ + This is the default directory that will be used to find secrets as individual files under. This can be overridden using + FLYTE_SECRETS_DEFAULT_DIR. + """ + + FILE_PREFIX = ConfigEntry(LegacyConfigEntry(SECTION, "file_prefix")) + """ + This is the prefix for the file in the default dir. + """ + + +class StatsD(object): + SECTION = "secrets" + # StatsD Config flags should ideally be controlled at the platform level and not through flytekit's config file. + # They are meant to allow administrators to control certain behavior according to how the system is configured. 
+ + HOST = ConfigEntry(LegacyConfigEntry(SECTION, "host")) + PORT = ConfigEntry(LegacyConfigEntry(SECTION, "port", int)) + DISABLED = ConfigEntry(LegacyConfigEntry(SECTION, "disabled", bool)) + DISABLE_TAGS = ConfigEntry(LegacyConfigEntry(SECTION, "disable_tags", bool)) diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py deleted file mode 100644 index ee8bfeb895..0000000000 --- a/flytekit/configuration/platform.py +++ /dev/null @@ -1,4 +0,0 @@ -from flytekit.configuration import common as _config_common - -URL = _config_common.FlyteStringConfigurationEntry("platform", "url") -INSECURE = _config_common.FlyteBoolConfigurationEntry("platform", "insecure", default=False) diff --git a/flytekit/configuration/resources.py b/flytekit/configuration/resources.py deleted file mode 100644 index 8d7a0a0313..0000000000 --- a/flytekit/configuration/resources.py +++ /dev/null @@ -1,65 +0,0 @@ -from flytekit.configuration import common as _config_common - -DEFAULT_CPU_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_cpu_limit") -""" -If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes CPU -request/limit format. -""" - -DEFAULT_CPU_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_cpu_request") -""" -If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes -CPU request/limit format. -""" - -DEFAULT_MEMORY_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_memory_limit") -""" -If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes -memory request/limit format. -""" - -DEFAULT_MEMORY_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_memory_request") -""" -If not specified explicitly when constructing a task, this request will be applied as the default. 
Follows Kubernetes -memory request/limit format. -""" - -DEFAULT_GPU_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_gpu_limit") -""" -If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes GPU -request/limit format. -""" - -DEFAULT_GPU_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_gpu_request") -""" -If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes -GPU request/limit format. -""" - -DEFAULT_STORAGE_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_storage_limit") -""" -If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes -storage request/limit format. -""" - -DEFAULT_STORAGE_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_storage_request") -""" -If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes -storage request/limit format. -""" - -DEFAULT_EPHEMERAL_STORAGE_LIMIT = _config_common.FlyteStringConfigurationEntry( - "resources", "default_ephemeral_storage_limit" -) -""" -If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes -ephemeral storage request/limit format. -""" - -DEFAULT_EPHEMERAL_STORAGE_REQUEST = _config_common.FlyteStringConfigurationEntry( - "resources", "default_ephemeral_storage_request" -) -""" -If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes -ephemeral storage request/limit format. 
-""" diff --git a/flytekit/configuration/sdk.py b/flytekit/configuration/sdk.py deleted file mode 100644 index a142b29546..0000000000 --- a/flytekit/configuration/sdk.py +++ /dev/null @@ -1,38 +0,0 @@ -import tempfile - -from flytekit.configuration import common as _config_common - -WORKFLOW_PACKAGES = _config_common.FlyteStringListConfigurationEntry("sdk", "workflow_packages", default=[]) -""" -This is a comma-delimited list of packages that SDK tools will use to discover entities for the purpose of registration -and execution of entities. -""" - -LOCAL_SANDBOX = _config_common.FlyteStringConfigurationEntry( - "sdk", - "local_sandbox", - default=tempfile.mkdtemp(prefix="flyte"), -) -""" -This is the path where SDK will place files during local executions and testing. The SDK will not automatically -clean up data in these directories. -""" - -LOGGING_LEVEL = _config_common.FlyteIntegerConfigurationEntry("sdk", "logging_level", default=20) -""" -This is the default logging level for the Python logging library and will be set before user code runs. -Note that this configuration is special in that it is a runtime setting, not a compile time setting. This is the only -runtime option in this file. -""" - -PARQUET_ENGINE = _config_common.FlyteStringConfigurationEntry("sdk", "parquet_engine", default="pyarrow") -""" -This is the parquet engine to use when reading data from parquet files. -""" - -# Feature Gate -USE_STRUCTURED_DATASET = _config_common.FlyteBoolConfigurationEntry("sdk", "use_structured_dataset", default=False) -""" -Note: This gate will be switched to True at some point in the future. Definitely by 1.0, if not v0.31.0. 
- -""" diff --git a/flytekit/configuration/secrets.py b/flytekit/configuration/secrets.py deleted file mode 100644 index 46cb82f835..0000000000 --- a/flytekit/configuration/secrets.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from flytekit.configuration import common as _common_config - -# Secrets management -SECRETS_ENV_PREFIX = _common_config.FlyteStringConfigurationEntry("secrets", "env_prefix", default="_FSEC_") -""" -This is the prefix that will be used to lookup for injected secrets at runtime. This can be overridden to using -FLYTE_SECRETS_ENV_PREFIX variable -""" - -SECRETS_DEFAULT_DIR = _common_config.FlyteStringConfigurationEntry( - "secrets", "default_dir", default=os.path.join(os.sep, "etc", "secrets") -) -""" -This is the default directory that will be used to find secrets as individual files under. This can be overridden using -FLYTE_SECRETS_DEFAULT_DIR. -""" - -SECRETS_FILE_PREFIX = _common_config.FlyteStringConfigurationEntry("secrets", "file_prefix", default="") -""" -This is the prefix for the file in the default dir. -""" diff --git a/flytekit/configuration/statsd.py b/flytekit/configuration/statsd.py deleted file mode 100644 index d3d7b9852f..0000000000 --- a/flytekit/configuration/statsd.py +++ /dev/null @@ -1,9 +0,0 @@ -from flytekit.configuration import common as _common_config - -# StatsD Config flags should ideally be controlled at the platform level and not through flytekit's config file. -# They are meant to allow administrators to control certain behavior according to how the system is configured. 
- -HOST = _common_config.FlyteStringConfigurationEntry("statsd", "host", default="localhost") -PORT = _common_config.FlyteIntegerConfigurationEntry("statsd", "port", default=8125) -DISABLED = _common_config.FlyteBoolConfigurationEntry("statsd", "disabled", default=False) -DISABLE_TAGS = _common_config.FlyteBoolConfigurationEntry("statsd", "disable_tags", default=False) diff --git a/flytekit/core/__init__.py b/flytekit/core/__init__.py index e69de29bb2..08df61969c 100644 --- a/flytekit/core/__init__.py +++ b/flytekit/core/__init__.py @@ -0,0 +1 @@ +SERIALIZED_CONTEXT_ENV_VAR = "_F_SS_C" diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 0ca49a9169..bcf9d0bfc7 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -23,13 +23,8 @@ from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union -from flytekit.core.context_manager import ( - ExecutionParameters, - FlyteContext, - FlyteContextManager, - FlyteEntities, - SerializationSettings, -) +from flytekit.configuration import SerializationSettings +from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( @@ -290,13 +285,13 @@ def __call__(self, *args, **kwargs): def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") - def get_container(self, settings: SerializationSettings) -> _task_model.Container: + def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]: """ Returns the container definition (if any) that is used to run the task on hosted Flyte. 
""" return None - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + def get_k8s_pod(self, settings: SerializationSettings) -> Optional[_task_model.K8sPod]: """ Returns the kubernetes pod definition (if any) that is used to run the task on hosted Flyte. """ @@ -308,13 +303,13 @@ def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: """ return None - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: """ Return additional plugin-specific custom data (if any) as a serializable dictionary. """ return None - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: """ Returns the task config as a serializable dictionary. This task config consists of metadata about the custom defined for this task. diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index 33addbe598..d47820f811 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -1,7 +1,7 @@ from typing import List +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import TaskResolverMixin -from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d46057c623..a31ad8150b 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,8 +1,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Type +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata -from flytekit.core.context_manager import SerializationSettings from 
flytekit.core.interface import Interface from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.utils import _get_container_definition diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index e1f77cb123..54017e32b9 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -17,7 +17,6 @@ import logging as _logging import os import pathlib -import re import tempfile import traceback import typing @@ -27,12 +26,8 @@ from enum import Enum from typing import Any, Dict, Generator, List, Optional, Union -from docker_image import reference - from flytekit.clients import friendly as friendly_client # noqa -from flytekit.configuration import images, internal -from flytekit.configuration import sdk as _sdk_config -from flytekit.configuration import secrets +from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider @@ -48,92 +43,13 @@ if typing.TYPE_CHECKING: from flytekit.core.base_task import TaskResolverMixin -_DEFAULT_FLYTEKIT_ENTRYPOINT_FILELOC = "bin/entrypoint.py" - - -@dataclass(init=True, repr=True, eq=True, frozen=True) -class Image(object): - """ - Image is a structured wrapper for task container images used in object serialization. - - Attributes: - name (str): A user-provided name to identify this image. - fqn (str): Fully qualified image name. This consists of - #. a registry location - #. a username - #. a repository name - For example: `hostname/username/reponame` - tag (str): Optional tag used to specify which version of an image to pull - """ - - name: str - fqn: str - tag: str +# Identifier fields use placeholders for registration-time substitution. 
+# Additional fields, such as auth and the raw output data prefix have more complex structures +# and can be optional so they are not serialized with placeholders. - @property - def full(self) -> str: - """ " - Return the full image name with tag. - """ - return f"{self.fqn}:{self.tag}" - - -@dataclass(init=True, repr=True, eq=True, frozen=True) -class ImageConfig(object): - """ - ImageConfig holds available images which can be used at registration time. A default image can be specified - along with optional additional images. Each image in the config must have a unique name. - - Attributes: - default_image (str): The default image to be used as a container for task serialization. - images (List[Image]): Optional, additional images which can be used in task container definitions. - """ - - default_image: Optional[Image] = None - images: Optional[List[Image]] = None - - def find_image(self, name) -> Optional[Image]: - """ - Return an image, by name, if it exists. - """ - lookup_images = self.images + [self.default_image] if self.images else [self.default_image] - for i in lookup_images: - if i.name == name: - return i - return None - - -_IMAGE_FQN_TAG_REGEX = re.compile(r"([^:]+)(?=:.+)?") - - -def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image: - """ - Looks up the image tag from environment variable (should be set from the Dockerfile). - FLYTE_INTERNAL_IMAGE should be the environment variable. - - This function is used when registering tasks/workflows with Admin. - When using the canonical Python-based development cycle, the version that is used to register workflows - and tasks with Admin should be the version of the image itself, which should ideally be something unique - like the sha of the latest commit. - - :param optional_tag: - :param name: - :param Text tag: e.g. 
somedocker.com/myimage:someversion123 - :rtype: Text - """ - ref = reference.Reference.parse(tag) - if not optional_tag and ref["tag"] is None: - raise AssertionError(f"Incorrectly formatted image {tag}, missing tag value") - else: - return Image(name=name, fqn=ref["name"], tag=ref["tag"]) - - -def get_image_config(img_name: Optional[str] = None) -> ImageConfig: - image_name = img_name if img_name else internal.IMAGE.get() - default_img = look_up_image_info("default", image_name) if image_name is not None and image_name != "" else None - other_images = [look_up_image_info(k, tag=v, optional_tag=True) for k, v in images.get_specified_images().items()] - other_images.append(default_img) - return ImageConfig(default_image=default_img, images=other_images) +# During out of container serialize the absolute path of the flytekit virtualenv at serialization time won't match the +# in-container value at execution time. The following default value is used to provide the in-container virtualenv path +# but can be optionally overridden at serialization time based on the installation of your flytekit virtualenv. 
class ExecutionParameters(object): @@ -353,10 +269,12 @@ def __getattr__(self, item: str) -> str: """ return self._sm.get(self._group, item) - def __init__(self): - self._base_dir = str(secrets.SECRETS_DEFAULT_DIR.get()).strip() - self._file_prefix = str(secrets.SECRETS_FILE_PREFIX.get()).strip() - self._env_prefix = str(secrets.SECRETS_ENV_PREFIX.get()).strip() + def __init__(self, secrets_cfg: typing.Optional[SecretsConfig] = None): + if secrets_cfg is None: + secrets_cfg = SecretsConfig.auto() + self._base_dir = secrets_cfg.default_dir.strip() + self._file_prefix = secrets_cfg.file_prefix.strip() + self._env_prefix = secrets_cfg.env_prefix.strip() def __getattr__(self, item: str) -> _GroupSecrets: """ @@ -403,119 +321,6 @@ def check_group_key(group: str, key: str): raise ValueError("secrets key is a mandatory field.") -@dataclass -class EntrypointSettings(object): - """ - This object carries information about the command, path and version of the entrypoint program that will be invoked - to execute tasks at runtime. - """ - - path: Optional[str] = None - command: Optional[str] = None - version: int = 0 - - -@dataclass -class FastSerializationSettings(object): - """ - This object hold information about settings necessary to serialize an object so that it can be fast-registered. - """ - - enabled: bool = False - # This is the location that the code should be copied into. - destination_dir: Optional[str] = None - - # This is the zip file where the new code was uploaded to. - distribution_location: Optional[str] = None - - -@dataclass(frozen=True) -class SerializationSettings(object): - """ - These settings are provided while serializing a workflow and task, before registration. This is required to get - runtime information at serialization time, as well as some defaults. - - Attributes: - project (str): The project (if any) with which to register entities under. - domain (str): The domain (if any) with which to register entities under. 
- version (str): The version (if any) with which to register entities under. - image_config (ImageConfig): The image config used to define task container images. - env (Optional[Dict[str, str]]): Environment variables injected into task container definitions. - flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit - virtualenv at serialization time won't match the in-container value at execution time. This optional value - is used to provide the in-container virtualenv path - python_interpreter (Optional[str]): The python executable to use. This is used for spark tasks in out of - container execution. - entrypoint_settings (Optional[EntrypointSettings]): Information about the command, path and version of the - entrypoint program. - fast_serialization_settings (Optional[FastSerializationSettings]): If the code is being serialized so that it - can be fast registered (and thus omit building a Docker image) this object contains additional parameters - for serialization. 
- """ - - project: str - domain: str - version: str - image_config: ImageConfig - env: Optional[Dict[str, str]] = None - flytekit_virtualenv_root: Optional[str] = None - python_interpreter: Optional[str] = None - entrypoint_settings: Optional[EntrypointSettings] = None - fast_serialization_settings: Optional[FastSerializationSettings] = None - - @dataclass - class Builder(object): - project: str - domain: str - version: str - image_config: ImageConfig - env: Optional[Dict[str, str]] = None - flytekit_virtualenv_root: Optional[str] = None - python_interpreter: Optional[str] = None - entrypoint_settings: Optional[EntrypointSettings] = None - fast_serialization_settings: Optional[FastSerializationSettings] = None - - def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: - self.fast_serialization_settings = fss - return self - - def build(self) -> SerializationSettings: - return SerializationSettings( - project=self.project, - domain=self.domain, - version=self.version, - image_config=self.image_config, - env=self.env, - flytekit_virtualenv_root=self.flytekit_virtualenv_root, - python_interpreter=self.python_interpreter, - entrypoint_settings=self.entrypoint_settings, - fast_serialization_settings=self.fast_serialization_settings, - ) - - def new_builder(self) -> Builder: - """ - Creates a ``SerializationSettings.Builder`` that copies the existing serialization settings parameters and - allows for customization. 
- """ - return SerializationSettings.Builder( - project=self.project, - domain=self.domain, - version=self.version, - image_config=self.image_config, - env=self.env, - flytekit_virtualenv_root=self.flytekit_virtualenv_root, - python_interpreter=self.python_interpreter, - entrypoint_settings=self.entrypoint_settings, - fast_serialization_settings=self.fast_serialization_settings, - ) - - def should_fast_serialize(self) -> bool: - """ - Whether or not the serialization settings specify that entities should be serialized for fast registration. - """ - return self.fast_serialization_settings is not None and self.fast_serialization_settings.enabled - - @dataclass(frozen=True) class CompilationState(object): """ @@ -586,8 +391,6 @@ class ExecutionState(object): working_dir (os.PathLike): Specifies the remote, external directory where inputs, outputs and other protobufs are uploaded engine_dir (os.PathLike): - additional_context Optional[Dict[Any, Any]]: Free form dictionary used to store additional values, for example - those used for dynamic, fast registration. branch_eval_mode Optional[BranchEvalMode]: Used to determine whether a branch node should execute. user_space_params Optional[ExecutionParameters]: Provides run-time, user-centric context such as a statsd handler, a logging handler, the current execution id and a working directory. 
@@ -617,7 +420,6 @@ class Mode(Enum): mode: Optional[ExecutionState.Mode] working_dir: os.PathLike engine_dir: Optional[Union[os.PathLike, str]] - additional_context: Optional[Dict[Any, Any]] branch_eval_mode: Optional[BranchEvalMode] user_space_params: Optional[ExecutionParameters] @@ -626,7 +428,6 @@ def __init__( working_dir: os.PathLike, mode: Optional[ExecutionState.Mode] = None, engine_dir: Optional[Union[os.PathLike, str]] = None, - additional_context: Optional[Dict[Any, Any]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, user_space_params: Optional[ExecutionParameters] = None, ): @@ -636,7 +437,6 @@ def __init__( self.mode = mode self.engine_dir = engine_dir if engine_dir else os.path.join(self.working_dir, "engine_dir") pathlib.Path(self.engine_dir).mkdir(parents=True, exist_ok=True) - self.additional_context = additional_context self.branch_eval_mode = branch_eval_mode self.user_space_params = user_space_params @@ -659,24 +459,16 @@ def with_params( working_dir: Optional[os.PathLike] = None, mode: Optional[Mode] = None, engine_dir: Optional[os.PathLike] = None, - additional_context: Optional[Dict[Any, Any]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, user_space_params: Optional[ExecutionParameters] = None, ) -> ExecutionState: """ Produces a copy of the current execution state and overrides the copy's parameters with passed parameter values. 
""" - if self.additional_context: - if additional_context: - additional_context = {**self.additional_context, **additional_context} - else: - additional_context = self.additional_context - return ExecutionState( working_dir=working_dir if working_dir else self.working_dir, mode=mode if mode else self.mode, engine_dir=engine_dir if engine_dir else self.engine_dir, - additional_context=additional_context, branch_eval_mode=branch_eval_mode if branch_eval_mode else self.branch_eval_mode, user_space_params=user_space_params if user_space_params else self.user_space_params, ) @@ -952,8 +744,9 @@ def initialize(): # This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally default_execution_id = _identifier.WorkflowExecutionIdentifier(project="local", domain="local", name="local") + cfg = Config.auto() # Ensure a local directory is available for users to work with. - user_space_path = os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space") + user_space_path = os.path.join(cfg.local_sandbox_path, "user_space") pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True) # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 00c233dd8b..57931ff301 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -32,6 +32,7 @@ from typing import Dict, Union from uuid import UUID +from flytekit.configuration import DataConfig from flytekit.core.utils import PerformanceTimer from flytekit.exceptions.user import FlyteAssertion from flytekit.interfaces.random import random @@ -283,7 +284,12 @@ class FileAccessProvider(object): durable store. 
""" - def __init__(self, local_sandbox_dir: Union[str, os.PathLike], raw_output_prefix: str): + def __init__( + self, + local_sandbox_dir: Union[str, os.PathLike], + raw_output_prefix: str, + data_config: typing.Optional[DataConfig] = None, + ): """ Args: local_sandbox_dir: A local temporary working directory, that should be used to store data @@ -296,7 +302,9 @@ def __init__(self, local_sandbox_dir: Union[str, os.PathLike], raw_output_prefix self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) self._local = DiskPersistence(default_prefix=local_sandbox_dir_appended) - self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)(default_prefix=raw_output_prefix) + self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)( + default_prefix=raw_output_prefix, data_config=data_config + ) self._raw_output_prefix = raw_output_prefix @staticmethod @@ -427,5 +435,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul # TODO make this use tmpdir tmp_dir = os.path.join("/tmp/flyte", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) default_local_file_access_provider = FileAccessProvider( - local_sandbox_dir=os.path.join(tmp_dir, "sandbox"), raw_output_prefix=os.path.join(tmp_dir, "raw") + local_sandbox_dir=os.path.join(tmp_dir, "sandbox"), + raw_output_prefix=os.path.join(tmp_dir, "raw"), + data_config=DataConfig.auto(), ) diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 217db39652..c21247305a 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -300,6 +300,33 @@ def __init__( FlyteEntities.entities.append(self) + def clone_with( + self, + name: str, + parameters: _interface_models.ParameterMap = None, + fixed_inputs: _literal_models.LiteralMap = None, + schedule: _schedule_model.Schedule = None, + notifications: List[_common_models.Notification] = None, + labels: _common_models.Labels = None, + annotations: _common_models.Annotations = None, + 
raw_output_data_config: _common_models.RawOutputDataConfig = None, + auth_role: _common_models.AuthRole = None, + max_parallelism: int = None, + ) -> LaunchPlan: + return LaunchPlan( + name=name, + workflow=self.workflow, + parameters=parameters or self.parameters, + fixed_inputs=fixed_inputs or self.fixed_inputs, + schedule=schedule or self.schedule, + notifications=notifications or self.notifications, + labels=labels or self.labels, + annotations=annotations or self.annotations, + raw_output_data_config=raw_output_data_config or self.raw_output_data_config, + auth_role=auth_role or self._auth_role, + max_parallelism=max_parallelism or self.max_parallelism, + ) + @property def python_interface(self) -> Interface: return self.workflow.python_interface diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 731be53ba8..f95b0f66d0 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -8,9 +8,10 @@ from itertools import count from typing import Any, Dict, List, Optional, Type +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.constants import SdkTaskType -from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask from flytekit.exceptions import scopes as exception_scopes diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a182ead2cc..7ecc5e205c 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -786,7 +786,8 @@ def create_and_link_node_from_remote( for k in sorted(typed_interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, 
var.type)) + # TODO to improve the error message, should we show python equivalent types for var.type? + raise _user_exceptions.FlyteAssertion("Missing input `{}` type `{}`".format(k, var.type)) v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index c5f8413fea..35800c3c87 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -5,8 +5,9 @@ from abc import ABC from typing import Callable, Dict, List, Optional, TypeVar +from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskResolverMixin -from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index c5a716c3cb..4474994fb9 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -5,8 +5,9 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin -from flytekit.core.context_manager import FlyteContext, Image, ImageConfig, SerializationSettings +from flytekit.core.context_manager import FlyteContext from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.tracker import 
TrackedInstance diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 2e102c8b62..17a99c1bd3 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -20,7 +20,7 @@ from typing import Any, Callable, List, Optional, TypeVar, Union from flytekit.core.base_task import Task, TaskResolverMixin -from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring from flytekit.core.interface import transform_function_to_interface from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver @@ -233,17 +233,6 @@ def compile_into_workflow( # DynamicJobSpec later tts.append(model.template) - if ctx.serialization_settings.should_fast_serialize(): - if ( - not ctx.execution_state - or not ctx.execution_state.additional_context - or not ctx.execution_state.additional_context.get("dynamic_addl_distro") - ): - raise AssertionError( - "Compilation for a dynamic workflow called in fast execution mode but no additional code " - "distribution could be retrieved" - ) - dj_spec = _dynamic_job.DynamicJobSpec( min_successes=len(workflow_spec.template.nodes), tasks=tts, @@ -275,24 +264,6 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: return exception_scopes.user_entry_point(task_function)(**kwargs) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: - is_fast_execution = bool( - ctx.execution_state - and ctx.execution_state.additional_context - and ctx.execution_state.additional_context.get("dynamic_addl_distro") - ) - if is_fast_execution: - ctx = ctx.with_serialization_settings( - ctx.serialization_settings.new_builder() - .with_fast_serialization_settings( - FastSerializationSettings( - enabled=True, - 
destination_dir=ctx.execution_state.additional_context.get("dynamic_dest_dir", "."), - distribution_location=ctx.execution_state.additional_context.get("dynamic_addl_distro"), - ) - ) - .build() - ) - return self.compile_into_workflow(ctx, task_function, **kwargs) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION: diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index 7f87c4de91..d23aae3fbb 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Dict, List, Optional -from flytekit.configuration import resources as _resource_config from flytekit.loggers import logger from flytekit.models import task as _task_models @@ -66,16 +65,16 @@ def _get_container_definition( memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, ) -> _task_models.Container: - storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() - storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() - ephemeral_storage_limit = ephemeral_storage_limit or _resource_config.DEFAULT_EPHEMERAL_STORAGE_LIMIT.get() - ephemeral_storage_request = ephemeral_storage_request or _resource_config.DEFAULT_EPHEMERAL_STORAGE_REQUEST.get() - cpu_limit = cpu_limit or _resource_config.DEFAULT_CPU_LIMIT.get() - cpu_request = cpu_request or _resource_config.DEFAULT_CPU_REQUEST.get() - gpu_limit = gpu_limit or _resource_config.DEFAULT_GPU_LIMIT.get() - gpu_request = gpu_request or _resource_config.DEFAULT_GPU_REQUEST.get() - memory_limit = memory_limit or _resource_config.DEFAULT_MEMORY_LIMIT.get() - memory_request = memory_request or _resource_config.DEFAULT_MEMORY_REQUEST.get() + storage_limit = storage_limit + storage_request = storage_request + ephemeral_storage_limit = ephemeral_storage_limit + ephemeral_storage_request = ephemeral_storage_request + cpu_limit = cpu_limit + cpu_request = cpu_request + gpu_limit = gpu_limit 
+ gpu_request = gpu_request + memory_limit = memory_limit + memory_request = memory_request requests = [] if storage_request: diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index d420310fa2..f6635a4a57 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -33,11 +33,12 @@ DataPersistencePlugins """ +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.context_manager import ExecutionState, Image, ImageConfig, SecretsManager, SerializationSettings +from flytekit.core.context_manager import ExecutionState, SecretsManager from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.interface import Interface from flytekit.core.promise import Promise diff --git a/flytekit/extras/cloud_pickle_resolver.py b/flytekit/extras/cloud_pickle_resolver.py index 99ca5438c4..94a7a34d83 100644 --- a/flytekit/extras/cloud_pickle_resolver.py +++ b/flytekit/extras/cloud_pickle_resolver.py @@ -3,8 +3,8 @@ import cloudpickle +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import TaskResolverMixin -from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py index 7e7711d64a..86353f3bf9 100644 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ b/flytekit/extras/persistence/gcs_gsutil.py @@ -2,7 +2,7 @@ import typing from shutil import which as shell_which -from flytekit.configuration import gcp +from flytekit.configuration import DataConfig, GCSConfig from 
flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.exceptions.user import FlyteUserException from flytekit.tools import subprocess @@ -32,8 +32,9 @@ class GCSPersistence(DataPersistence): _GS_UTIL_CLI = "gsutil" PROTOCOL = "gs://" - def __init__(self, default_prefix: typing.Optional[str] = None): + def __init__(self, default_prefix: typing.Optional[str] = None, data_config: typing.Optional[DataConfig] = None): super(GCSPersistence, self).__init__(name="gcs-gsutil", default_prefix=default_prefix) + self.gcs_cfg = data_config.gcs if data_config else GCSConfig.auto() @staticmethod def _check_binary(): @@ -43,8 +44,7 @@ def _check_binary(): if not shell_which(GCSPersistence._GS_UTIL_CLI): raise FlyteUserException("gsutil (gcloud cli) not found! Please install using `pip install gsutil`.") - @staticmethod - def _maybe_with_gsutil_parallelism(*gsutil_args): + def _maybe_with_gsutil_parallelism(self, *gsutil_args): """ Check if we should run `gsutil` with the `-m` flag that enables parallelism via multiple threads/processes. 
Additional tweaking of @@ -52,7 +52,7 @@ def _maybe_with_gsutil_parallelism(*gsutil_args): https://cloud.google.com/storage/docs/boto-gsutil """ cmd = [GCSPersistence._GS_UTIL_CLI] - if gcp.GSUTIL_PARALLELISM.get(): + if self.gcs_cfg.gsutil_parallelism: cmd.append("-m") cmd.extend(gsutil_args) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py index 64e09e219c..b73126c17c 100644 --- a/flytekit/extras/persistence/s3_awscli.py +++ b/flytekit/extras/persistence/s3_awscli.py @@ -1,3 +1,4 @@ +import os import os as _os import re as _re import string as _string @@ -6,30 +7,34 @@ from shutil import which as shell_which from typing import Dict, List, Optional -from flytekit.configuration import aws +from flytekit.configuration import DataConfig, S3Config, internal from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.exceptions.user import FlyteUserException from flytekit.loggers import logger from flytekit.tools import subprocess S3_ANONYMOUS_FLAG = "--no-sign-request" +S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" +S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" -def _update_cmd_config_and_execute(cmd: List[str]): +def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]): env = _os.environ.copy() - if aws.ENABLE_DEBUG.get(): + if s3_cfg.enable_debug: cmd.insert(1, "--debug") - if aws.S3_ENDPOINT.get() is not None: - cmd.insert(1, aws.S3_ENDPOINT.get()) - cmd.insert(1, aws.S3_ENDPOINT_ARG_NAME) + if s3_cfg.endpoint is not None: + cmd.insert(1, s3_cfg.endpoint) + cmd.insert(1, "--endpoint-url") - if aws.S3_ACCESS_KEY_ID.get() is not None: - env[aws.S3_ACCESS_KEY_ID_ENV_NAME] = aws.S3_ACCESS_KEY_ID.get() + if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: + if s3_cfg.access_key_id: + env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id - if aws.S3_SECRET_ACCESS_KEY.get() is not None: - env[aws.S3_SECRET_ACCESS_KEY_ENV_NAME] = aws.S3_SECRET_ACCESS_KEY.get() + if 
S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ: + if s3_cfg.secret_access_key: + env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key retry = 0 while True: @@ -48,11 +53,11 @@ def _update_cmd_config_and_execute(cmd: List[str]): except Exception as e: logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") retry += 1 - if retry > aws.RETRIES.get(): + if retry > s3_cfg.retries: raise - secs = aws.BACKOFF_SECONDS.get() - logger.info(f"Sleeping before retrying again, after {secs} seconds") - time.sleep(secs) + secs = s3_cfg.backoff + logger.info(f"Sleeping before retrying again, after {secs.total_seconds()} seconds") + time.sleep(secs.total_seconds()) logger.info("Retrying again") @@ -82,8 +87,9 @@ class S3Persistence(DataPersistence): _AWS_CLI = "aws" _SHARD_CHARACTERS = [str(x) for x in range(10)] + list(_string.ascii_lowercase) - def __init__(self, default_prefix: Optional[str] = None): + def __init__(self, default_prefix: Optional[str] = None, data_config: typing.Optional[DataConfig] = None): super().__init__(name="awscli-s3", default_prefix=default_prefix) + self.s3_cfg = data_config.s3 if data_config else S3Config.auto() @staticmethod def _check_binary(): @@ -122,7 +128,7 @@ def exists(self, remote_path): file_path, ] try: - _update_cmd_config_and_execute(cmd) + _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) return True except Exception as ex: # The s3api command returns an error if the object does not exist. 
The error message contains @@ -144,7 +150,7 @@ def get(self, from_path: str, to_path: str, recursive: bool = False): cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path] else: cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path] - return _update_cmd_config_and_execute(cmd) + return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) def put(self, from_path: str, to_path: str, recursive: bool = False): extra_args = { @@ -160,7 +166,7 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): cmd += ["--recursive"] cmd.extend(_extra_args(extra_args)) cmd += [from_path, to_path] - return _update_cmd_config_and_execute(cmd) + return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: paths = list(paths) # make type check happy diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index e4a803f50a..1018b5254b 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -9,8 +9,8 @@ import pandas as pd from flytekit import FlyteContext, kwtypes +from flytekit.configuration import SerializationSettings from flytekit.core.base_sql_task import SQLTask -from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor from flytekit.models import task as task_models diff --git a/flytekit/interfaces/stats/client.py b/flytekit/interfaces/stats/client.py index 759979670a..a49b482c09 100644 --- a/flytekit/interfaces/stats/client.py +++ b/flytekit/interfaces/stats/client.py @@ -5,7 +5,7 @@ import statsd -from flytekit.configuration import statsd as _statsd_config +from flytekit.configuration import StatsConfig RESERVED_TAG_WORDS = frozenset( ["asg", "az", "backend", "canary", "host", "period", "region", "shard", "window", "source"] @@ -129,21 
+129,21 @@ def _prefix(self): return self._scope_prefix -def _get_stats_client(): +def _get_stats_client(cfg: StatsConfig): global _stats_client - if _statsd_config.DISABLED.get() is True: + if cfg.disabled is True: _stats_client = DummyStatsClient() if _stats_client is None: - _stats_client = statsd.StatsClient(_statsd_config.HOST.get(), _statsd_config.PORT.get()) + _stats_client = statsd.StatsClient(cfg.host, cfg.port) return _stats_client -def get_base_stats(prefix): - return StatsClientProxy(_get_stats_client(), prefix=prefix) +def get_base_stats(cfg: StatsConfig, prefix: str): + return StatsClientProxy(_get_stats_client(cfg), prefix=prefix) -def get_stats(prefix): - return get_base_stats(prefix) +def get_stats(cfg: StatsConfig, prefix: str): + return get_base_stats(cfg, prefix) class DummyStatsClient(statsd.StatsClient): diff --git a/flytekit/interfaces/stats/taggable.py b/flytekit/interfaces/stats/taggable.py index b2732bc794..09d8cb4ddd 100644 --- a/flytekit/interfaces/stats/taggable.py +++ b/flytekit/interfaces/stats/taggable.py @@ -1,4 +1,6 @@ -from flytekit.configuration.statsd import DISABLE_TAGS +from typing import Dict, List + +from flytekit.configuration import StatsConfig from flytekit.interfaces.stats import client as _stats_client @@ -6,10 +8,11 @@ class TaggableStats(_stats_client.ScopeableStatsProxy): # List of functions we will proxy and prefix the first string argument on EXTENDABLE_FUNC = ["incr", "decr", "timing", "timer", "gauge", "set"] - def __init__(self, client, full_prefix, prefix=None, tags=None): + def __init__(self, client, full_prefix, cfg: StatsConfig, prefix=None, tags=None): super(TaggableStats, self).__init__(client, prefix=prefix) self._tags = tags if tags else {} self._full_prefix = full_prefix + self._cfg = cfg def _create_wrapped_function(self, base_func): if self._scope_prefix: @@ -20,7 +23,7 @@ def name_wrap(stat, *args, **kwargs): if kwargs.pop("per_host", False): tags["_f"] = "i" - if bool(tags) and not 
DISABLE_TAGS.get(): + if bool(tags) and not self._cfg.disabled_tags: stat = self._serialize_tags(stat, tags) return base_func(self._p_with_prefix(stat), *args, **kwargs) @@ -32,7 +35,7 @@ def name_wrap(stat, *args, **kwargs): if kwargs.pop("per_host", False): tags["_f"] = "i" - if bool(tags) and not DISABLE_TAGS.get(): + if bool(tags) and not self._cfg.disabled_tags: stat = self._serialize_tags(stat, tags) return base_func(stat, *args, **kwargs) @@ -48,6 +51,7 @@ def pipeline(self): return TaggableStats( self._client.pipeline(), self._full_prefix, + cfg=self._cfg, prefix=self._scope_prefix, tags=dict(self._tags), ) @@ -56,6 +60,7 @@ def __enter__(self): return TaggableStats( self._client.__enter__(), self._full_prefix, + cfg=self._cfg, prefix=self._scope_prefix, tags=dict(self._tags), ) @@ -79,13 +84,13 @@ def full_prefix(self): return self._full_prefix -def get_stats(prefix, tags=None): +def get_stats(cfg: StatsConfig, prefix: str, tags: Dict[str, str] = None) -> TaggableStats: """ :rtype: TaggableStats """ # If tagging is disabled, do not pass tags to the constructor. 
- if DISABLE_TAGS.get(): + if cfg.disabled_tags: tags = None - return TaggableStats(_stats_client.get_base_stats(prefix.lower()), prefix.lower(), tags=tags) + return TaggableStats(_stats_client.get_base_stats(cfg, prefix.lower()), prefix.lower(), cfg=cfg, tags=tags) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index caf2471e58..70af841cd0 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -755,6 +755,9 @@ def env(self): """ return self._env + def add_env(self, key: str, val: str): + self._env[key] = val + @property def config(self): """ diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 56b93f6715..3f86c0f948 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -10,20 +10,15 @@ import typing import uuid from collections import OrderedDict -from copy import deepcopy from dataclasses import asdict, dataclass from datetime import datetime, timedelta -import grpc from flyteidl.core import literals_pb2 as literals_pb2 from flytekit.clients.friendly import SynchronousFlyteClient -from flytekit.configuration import internal -from flytekit.configuration import platform as platform_config -from flytekit.configuration import sdk as sdk_config -from flytekit.configuration import set_flyte_config_file from flytekit.core import constants, context_manager, utils from flytekit.core.interface import Interface +from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException from flytekit.loggers import remote_logger @@ -36,19 +31,16 @@ from singledispatchmethod import singledispatchmethod from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions -from flytekit.clis.flyte_cli.main import _detect_default_config_file -from flytekit.clis.sdk_in_container import serialize -from flytekit.configuration import auth as 
auth_config -from flytekit.configuration.internal import DOMAIN, PROJECT +from flytekit.configuration import Config, SerializationSettings from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.launch_plan import LaunchPlan from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.models import common as common_models -from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models +from flytekit.models import security from flytekit.models.admin.common import Sort from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata @@ -64,13 +56,34 @@ from flytekit.remote.nodes import FlyteNode from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow -from flytekit.tools.translator import FlyteControlPlaneEntity, FlyteLocalEntity, get_serializable +from flytekit.tools.translator import get_serializable_launch_plan, get_serializable_task, get_serializable_workflow ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) +@dataclass +class Options(object): + """ + Args: + raw_data_prefix: str -> remote prefix for storage location of the form ``s3:///key...`` or + ``gcs://...`` or ``file://...``. If not specified will use the platform configured default. + auth_role: common_models.AuthRole -> Specifies the Kubernetes Service account, IAM role etc to be used. 
If not + specified defaults will be used + + """ + + raw_data_prefix: typing.Optional[str] = None + auth_role: typing.Optional[common_models.AuthRole] = None + labels: typing.Optional[common_models.Labels] = None + annotations: typing.Optional[common_models.Annotations] = None + security_context: typing.Optional[security.SecurityContext] = None + max_parallelism: typing.Optional[int] = None + notifications: typing.Optional[typing.List[common_models.Notification]] = None + disable_notifications: typing.Optional[bool] = None + + @dataclass class ResolvedIdentifiers: project: str @@ -121,132 +134,42 @@ class FlyteRemote(object): """ - @classmethod - def from_config( - cls, - default_project: typing.Optional[str] = None, - default_domain: typing.Optional[str] = None, - config_file_path: typing.Optional[str] = None, - grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, - venv_root: typing.Optional[str] = None, - ) -> FlyteRemote: - """Create a FlyteRemote object using flyte configuration variables and/or environment variable overrides. - - :param default_project: default project to use when fetching or executing flyte entities. - :param default_domain: default domain to use when fetching or executing flyte entities. - :param config_file_path: config file to use when connecting to flyte admin. we will use '~/.flyte/config' by default. 
- :param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials` - """ - - if config_file_path is None: - _detect_default_config_file() - else: - set_flyte_config_file(config_file_path) - - raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( - sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw" - ) - - file_access = FileAccessProvider( - local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), - raw_output_prefix=raw_output_data_prefix, - ) - - venv_root = venv_root or serialize._DEFAULT_FLYTEKIT_VIRTUALENV_ROOT - entrypoint = context_manager.EntrypointSettings( - path=os.path.join(venv_root, serialize._DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC) - ) - - return cls( - flyte_admin_url=platform_config.URL.get(), - insecure=platform_config.INSECURE.get(), - default_project=default_project or PROJECT.get() or None, - default_domain=default_domain or DOMAIN.get() or None, - file_access=file_access, - auth_role=common_models.AuthRole( - assumable_iam_role=auth_config.ASSUMABLE_IAM_ROLE.get(), - kubernetes_service_account=auth_config.KUBERNETES_SERVICE_ACCOUNT.get(), - ), - notifications=None, - labels=None, - annotations=None, - image_config=get_image_config(), - raw_output_data_config=( - common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None - ), - grpc_credentials=grpc_credentials, - entrypoint_settings=entrypoint, - ) - def __init__( self, - flyte_admin_url: str, - insecure: bool, + config: Config, default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None, file_access: typing.Optional[FileAccessProvider] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, - notifications: typing.Optional[typing.List[common_models.Notification]] = None, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = 
None, - image_config: typing.Optional[ImageConfig] = None, - raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, - grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, - entrypoint_settings: typing.Optional[context_manager.EntrypointSettings] = None, + **kwargs, ): """Initialize a FlyteRemote object. - :param flyte_admin_url: url pointing to the remote backend. - :param insecure: whether or not the enable SSL. + :type kwargs: All arguments that can be passed to create the SynchronousFlyteClient. These are usually grpc + parameters, if you want to customize credentials, ssl handling etc. :param default_project: default project to use when fetching or executing flyte entities. :param default_domain: default domain to use when fetching or executing flyte entities. :param file_access: file access provider to use for offloading non-literal inputs/outputs. - :param auth_role: auth role config - :param notifications: notification config - :param labels: label config - :param annotations: annotation config - :param image_config: image config - :param raw_output_data_config: location for offloaded data, e.g. in S3 - :param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned - by :func:`grpc.ssl_channel_credentials` - :param entrypoint_settings: EntrypointSettings object for use with Spark tasks. If supplied, this will be - used when serializing Spark tasks, which need to know the path to the flytekit entrypoint.py file, - inside the container. """ - remote_logger.warning("This feature is still in beta. 
Its interface and UX is subject to change.") - if flyte_admin_url is None: - raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.") + if config is None or config.platform is None or config.platform.endpoint is None: + raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") - self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials) + self._client = SynchronousFlyteClient(config.platform, **kwargs) + self._config = config # read config files, env vars, host, ssl options for admin client - self._flyte_admin_url = flyte_admin_url - self._insecure = insecure self._default_project = default_project self._default_domain = default_domain - self._image_config = image_config - self._auth_role = auth_role - self._notifications = notifications - self._labels = labels - self._annotations = annotations - self._raw_output_data_config = raw_output_data_config - # Not exposing this as a property for now. - self._entrypoint_settings = entrypoint_settings - - raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( - sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw" - ) + self._file_access = file_access or FileAccessProvider( - local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), - raw_output_prefix=raw_output_data_prefix, + local_sandbox_dir=os.path.join(config.local_sandbox_path, "control_plane_metadata"), + raw_output_prefix="/tmp", + data_config=config.data_config, ) + # Save the file access object locally, but also make it available for use from the context. FlyteContextManager.with_context( FlyteContextManager.current_context().with_file_access(self._file_access).build() ) - # TODO: Reconsider whether we want this. Probably best to not cache. 
- self._serialized_entity_cache = OrderedDict() - @property def client(self) -> SynchronousFlyteClient: """Return a SynchronousFlyteClient for additional operations.""" @@ -263,93 +186,21 @@ def default_domain(self) -> str: return self._default_domain @property - def image_config(self) -> ImageConfig: + def config(self) -> Config: """Image config.""" - return self._image_config + return self._config @property def file_access(self) -> FileAccessProvider: """File access provider to use for offloading non-literal inputs/outputs.""" return self._file_access - @property - def auth_role(self): - """Auth role config.""" - return self._auth_role - - @property - def notifications(self): - """Notification config.""" - return self._notifications - - @property - def labels(self): - """Label config.""" - return self._labels - - @property - def annotations(self): - """Annotation config.""" - return self._annotations - - @property - def raw_output_data_config(self): - """Location for offloaded data, e.g. 
in S3""" - return self._raw_output_data_config - - @property - def version(self) -> str: - """Get a randomly generated version string.""" - return uuid.uuid4().hex[:30] + str(int(time.time())) - def remote_context(self): """Context manager with remote-specific configuration.""" return FlyteContextManager.with_context( FlyteContextManager.current_context().with_file_access(self.file_access) ) - def with_overrides( - self, - default_project: typing.Optional[str] = None, - default_domain: typing.Optional[str] = None, - flyte_admin_url: typing.Optional[str] = None, - insecure: typing.Optional[bool] = None, - file_access: typing.Optional[FileAccessProvider] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, - notifications: typing.Optional[typing.List[common_models.Notification]] = None, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = None, - image_config: typing.Optional[ImageConfig] = None, - raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, - ): - """Create a copy of the remote object, overriding the specified attributes.""" - new_remote = deepcopy(self) - if default_project: - new_remote._default_project = default_project - if default_domain: - new_remote._default_domain = default_domain - if flyte_admin_url: - new_remote._flyte_admin_url = flyte_admin_url - new_remote._client = SynchronousFlyteClient(flyte_admin_url, self._insecure) - if insecure: - new_remote._insecure = insecure - new_remote._client = SynchronousFlyteClient(self._flyte_admin_url, insecure) - if file_access: - new_remote._file_access = file_access - if auth_role: - new_remote._auth_role = auth_role - if notifications: - new_remote._notifications = notifications - if labels: - new_remote._labels = labels - if annotations: - new_remote._annotations = annotations - if image_config: - new_remote._image_config = image_config - if raw_output_data_config: - 
new_remote._raw_output_data_config = raw_output_data_config - return new_remote - def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask: """Fetch a task entity from flyte admin. @@ -512,126 +363,156 @@ def list_tasks_by_version( ) return [FlyteTask.promote_from_model(t.closure.compiled_task.template) for t in t_models] - ###################### - # Serialize Entities # - ###################### - - @singledispatchmethod - def _serialize( - self, - entity: FlyteLocalEntity, - project: str = None, - domain: str = None, - version: str = None, - **kwargs, - ) -> FlyteControlPlaneEntity: - """Serialize an entity for registration.""" - # TODO: Revisit cache - return get_serializable( - self._serialized_entity_cache, - SerializationSettings( - project or self.default_project, - domain or self.default_domain, - version or self.version, - self.image_config, - # https://github.com/flyteorg/flyte/issues/1359 - env={internal.IMAGE.env_var: self.image_config.default_image.full}, - entrypoint_settings=self._entrypoint_settings, - ), - entity=entity, - ) - ##################### # Register Entities # ##################### - @singledispatchmethod - def register( - self, - entity: typing.Union[PythonTask, WorkflowBase, LaunchPlan], - project: str = None, - domain: str = None, - name: str = None, - version: str = None, - ) -> typing.Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan]: - """Register an entity to flyte admin. - - :param entity: entity to register. - :param project: register entity into this project. If None, uses ``default_project`` attribute - :param domain: register entity into this domain. If None, uses ``default_domain`` attribute - :param name: register entity with this name. If None, uses ``entity.name`` - :param version: register entity with this version. If None, uses auto-generated version. 
- """ - raise NotImplementedError(f"entity type {type(entity)} not recognized for registration") + def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: + ident = Identifier( + resource_type=t, + project=ss.project or self.default_project, + domain=ss.domain or self.default_domain, + name=name, + version=version or ss.version, + ) + if not ident.project or not ident.domain or not ident.name or not ident.version: + raise ValueError( + f"To register a new {ident.resource_type}, (project, domain, name, version) required, " + f"received ({ident.project}, {ident.domain}, {ident.name}, {ident.version})." + ) + return ident - @register.register - def _( - self, entity: PythonTask, project: str = None, domain: str = None, name: str = None, version: str = None + def register_task( + self, entity: PythonTask, serialization_settings: SerializationSettings, version: typing.Optional[str] = None ) -> FlyteTask: - """Register an @task-decorated function or TaskTemplate task to flyte admin.""" - resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) - self.client.create_task( - Identifier(ResourceType.TASK, **resolved_identifiers), - task_spec=self._serialize(entity, **resolved_identifiers), - ) - return self.fetch_task(**resolved_identifiers) + """ + Register a qualified task (PythonTask) with Remote + For any conflicting parameters method arguments are regarded as overrides - @register.register - def _( - self, entity: WorkflowBase, project: str = None, domain: str = None, name: str = None, version: str = None - ) -> FlyteWorkflow: - """Register an @workflow-decorated function to flyte admin.""" - resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) - self.client.create_workflow( - Identifier(ResourceType.WORKFLOW, **resolved_identifiers), - workflow_spec=self._serialize(entity, **resolved_identifiers), + :param entity: PythonTask 
can be either @task or a instance of a Task class + :param serialization_settings: Settings that will be used to override various serialization parameters. + :param version: version that will be used to register. If not specified will default to using the serialization settings default + :return: + """ + m = OrderedDict() + task_spec = get_serializable_task(m, serialization_settings, entity) + ident = self._resolve_identifier(ResourceType.TASK, entity.name, version, serialization_settings) + self.client.create_task(ident, task_spec=task_spec) + return self.fetch_task( + serialization_settings.project, + serialization_settings.domain, + entity.name, + version or serialization_settings.version, ) - return self.fetch_workflow(**resolved_identifiers) - @register.register - def _( - self, entity: LaunchPlan, project: str = None, domain: str = None, name: str = None, version: str = None - ) -> FlyteLaunchPlan: - """Register a LaunchPlan object to flyte admin.""" - # See _get_patch_launch_plan_fn for what we need to patch. These are the elements of a launch plan - # that are not set at serialization time and are filled in either by flyte-cli register files or flytectl. 
- resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) - serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity, **resolved_identifiers) - if self.auth_role: - serialized_lp.spec._auth_role = common_models.AuthRole( - self.auth_role.assumable_iam_role, self.auth_role.kubernetes_service_account - ) - if self.raw_output_data_config: - serialized_lp.spec._raw_output_data_config = common_models.RawOutputDataConfig( - self.raw_output_data_config.output_location_prefix + def register_workflow( + self, + entity: WorkflowBase, + serialization_settings: SerializationSettings, + version: typing.Optional[str] = None, + default_launch_plan: bool = True, + all_downstream: bool = False, + options: typing.Optional[Options] = None, + ) -> FlyteWorkflow: + """ + Use this method to register a workflow. + :param version: version for the entity to be registered as + :param entity: The workflow to be registered + :param serialization_settings: The serialization settings to be used + :param default_launch_plan: This should be true if a default launch plan should be created for the workflow + :param all_downstream: This should be true if all downstream entities should be registered, including tasks, + subworkflows, launchplans + :param options: Additional execution options that can be configured for the default launchplan + :return: + """ + m = OrderedDict() + workflow_spec = get_serializable_workflow(m, serialization_settings, entity) + ident = self._resolve_identifier(ResourceType.WORKFLOW, entity.name, version, serialization_settings) + self.client.create_workflow(ident, workflow_spec=workflow_spec) + if default_launch_plan: + default_lp = LaunchPlan.get_default_launch_plan(FlyteContextManager.current_context(), entity) + self.register_launch_plan(default_lp, serialization_settings, version=version, options=options) + remote_logger.debug("Created default launch plan for Workflow") + + if all_downstream: + 
self._register_entity_if_not_exists( + entity, serialization_settings=serialization_settings, version=version, options=options ) - # Patch in labels and annotations - if self.labels: - for k, v in self.labels.values.items(): - serialized_lp.spec._labels.values[k] = v + return self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) - if self.annotations: - for k, v in self.annotations.values.items(): - serialized_lp.spec._annotations.values[k] = v + def register_launch_plan( + self, + entity: LaunchPlan, + serialization_settings: typing.Optional[SerializationSettings] = None, + version: typing.Optional[str] = None, + options: typing.Optional[Options] = None, + ) -> FlyteLaunchPlan: + """ + Register a given launchplan, possibly applying overrides from the provided options. + Note: In this case it is reasonable to have image_config in SerializationSettings to be None! + :param entity: Launchplan to be registered + :param serialization_settings: Settings to use for Serialization + :param version: + :param options: + :return: + """ + if not options: + options = Options() - self.client.create_launch_plan( - Identifier(ResourceType.LAUNCH_PLAN, **resolved_identifiers), - launch_plan_spec=serialized_lp.spec, + if serialization_settings is None: + serialization_settings = SerializationSettings( + image_config=None, project=self.default_project, domain=self.default_domain + ) + + raw = None + if options.raw_data_prefix: + raw = common_models.RawOutputDataConfig(options.raw_data_prefix) + + lp = entity.clone_with( + name=entity.name, + raw_output_data_config=raw, + auth_role=options.auth_role, + max_parallelism=options.max_parallelism, + notifications=options.notifications, + labels=options.labels, + annotations=options.annotations, ) - return self.fetch_launch_plan(**resolved_identifiers) + ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, serialization_settings) + m = OrderedDict() + idl_lp = get_serializable_launch_plan(m, 
serialization_settings, lp) + self.client.create_launch_plan(ident, idl_lp.spec) + return self.fetch_launch_plan(ident.project, ident.domain, ident.name, ident.version) - def _register_entity_if_not_exists(self, entity: WorkflowBase, resolved_identifiers_dict: dict): - # Try to register all the entity in WorkflowBase including LaunchPlan, PythonTask, or subworkflow. - node_identifiers_dict = deepcopy(resolved_identifiers_dict) + def _register_entity_if_not_exists( + self, + entity: WorkflowBase, + serialization_settings: SerializationSettings, + version: typing.Optional[str] = None, + options: typing.Optional[Options] = None, + ): for node in entity.nodes: try: - node_identifiers_dict["name"] = node.flyte_entity.name if isinstance(node.flyte_entity, WorkflowBase): - self._register_entity_if_not_exists(node.flyte_entity, node_identifiers_dict) - self.register(node.flyte_entity, **node_identifiers_dict) - elif isinstance(node.flyte_entity, PythonTask) or isinstance(node.flyte_entity, LaunchPlan): - self.register(node.flyte_entity, **node_identifiers_dict) + self.register_workflow( + node.flyte_entity, + serialization_settings, + version=version, + default_launch_plan=True, + all_downstream=True, + options=options, + ) + elif isinstance(node.flyte_entity, PythonTask): + self.register_task( + node.flyte_entity, serialization_settings=serialization_settings, version=version + ) + elif isinstance(node.flyte_entity, LaunchPlan): + self.register_launch_plan( + node.flyte_entity, + serialization_settings=serialization_settings, + version=version, + options=options, + ) else: raise NotImplementedError(f"We don't support registering this kind of entity: {node.flyte_entity}") except FlyteEntityAlreadyExistsException: @@ -643,64 +524,15 @@ def _register_entity_if_not_exists(self, entity: WorkflowBase, resolved_identifi # Execute Entities # #################### - def _resolve_identifier_kwargs( - self, - entity, - project: typing.Optional[str], - domain: typing.Optional[str], - 
name: typing.Optional[str], - version: typing.Optional[str], - ) -> ResolvedIdentifiers: - """ - Resolves the identifier attributes based on user input, falling back on the default project/domain and - auto-generated version, and ultimately the entity project/domain if entity is a remote flyte entity. - """ - error_msg = ( - "entity {entity} of type {entity_type} is not associated with a {arg_name}. Please specify the {arg_name} " - "argument when invoking the FlyteRemote.execute method or a default_{arg_name} value when initializig the " - "FlyteRemote object." - ) - - if project: - resolved_project, msg_project = project, "execute-method" - elif self.default_project: - resolved_project, msg_project = self.default_project, "remote" - elif hasattr(entity, "id"): - resolved_project, msg_project = entity.id.project, "entity" - else: - raise TypeError(error_msg.format(entity=entity, entity_type=type(entity), arg_name="project")) - - if domain: - resolved_domain, msg_domain = domain, "execute-method" - elif self.default_domain: - resolved_domain, msg_domain = self.default_domain, "remote" - elif hasattr(entity, "id"): - resolved_domain, msg_domain = entity.id.domain, "entity" - else: - raise TypeError(error_msg.format(entity=entity, entity_type=type(entity), arg_name="domain")) - - remote_logger.debug( - f"Using {msg_project}-supplied value for project and {msg_domain}-supplied value for domain." 
- ) - - return ResolvedIdentifiers( - resolved_project, - resolved_domain, - name or entity.name, - version or self.version, - ) - def _execute( self, entity: typing.Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan], inputs: typing.Dict[str, typing.Any], - project: str, - domain: str, - execution_name: typing.Optional[str] = None, + project: str = None, + domain: str = None, + execution_name: str = None, + options: typing.Optional[Options] = None, wait: bool = False, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. @@ -713,12 +545,15 @@ def _execute( :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ execution_name = execution_name or "f" + uuid.uuid4().hex[:19] - disable_all = self.notifications == [] - if disable_all: - notifications = None + if not options: + options = Options() + if options.disable_notifications is not None: + if options.disable_notifications: + notifications = None + else: + notifications = NotificationList(options.notifications) else: - notifications = NotificationList(self.notifications or []) - disable_all = None + notifications = NotificationList([]) with self.remote_context() as ctx: input_python_types = entity.guessed_python_interface.inputs @@ -736,8 +571,8 @@ def _execute( # in the case that I want to use a flyte entity from e.g. project "A" but actually execute the entity on a # different project "B". For now, this method doesn't support this use case. 
exec_id = self.client.create_execution( - project, - domain, + project or self.default_project, + domain or self.default_domain, execution_name, ExecutionSpec( entity.id, @@ -747,20 +582,54 @@ def _execute( 0, ), notifications=notifications, - disable_all=disable_all, - labels=labels or self.labels, - annotations=annotations or self.annotations, - auth_role=auth_role or self.auth_role, + disable_all=options.disable_notifications, + labels=options.labels, + annotations=options.annotations, + auth_role=options.auth_role, + max_parallelism=options.max_parallelism, ), literal_inputs, ) - except user_exceptions.FlyteEntityAlreadyExistsException: - exec_id = WorkflowExecutionIdentifier(flyte_id.project, flyte_id.domain, execution_name) + except user_exceptions.FlyteEntityAlreadyExistsException as e: + remote_logger.warning( + f"Execution with Execution ID {execution_name} already exists. " + f"Assuming this is the same execution, returning!" + ) + exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) execution = FlyteWorkflowExecution.promote_from_model(self.client.get_execution(exec_id)) if wait: return self.wait(execution) return execution + def _resolve_identifier_kwargs( + self, + entity: typing.Any, + project: str, + domain: str, + name: str, + from_project: str, + from_domain: str, + version: str, + ) -> ResolvedIdentifiers: + """ + Resolves the identifier attributes based on user input, falling back on the default project/domain and + auto-generated version, and ultimately the entity project/domain if entity is a remote flyte entity. 
+ """ + ident = ResolvedIdentifiers( + project=from_project or project or self.default_project, + domain=from_domain or domain or self.default_domain, + name=name or entity.name, + version=version, + ) + if not (ident.project and ident.domain and ident.name): + raise ValueError( + f"Cannot launch an execution with missing project/domain/name {ident} for entity type {type(entity)}." + f" Specify them in the execute method or when intializing FlyteRemote" + ) + return ident + @singledispatchmethod def execute( self, @@ -771,6 +640,9 @@ def execute( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a task, workflow, or launchplan. @@ -781,6 +653,9 @@ def execute( - ``@workflow``-decorated functions. - ``LaunchPlan`` objects. + :param options: + :param from_domain: + :param from_project: :param entity: entity to execute :param inputs: dictionary mapping argument names to values :param project: execute entity in this project. If entity doesn't exist in the project, register the entity @@ -813,6 +688,9 @@ def _( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. 
@@ -821,24 +699,14 @@ def _( """ if name or version: remote_logger.warning(f"The 'name' and 'version' arguments are ignored for entities of type {type(entity)}") - resolved_identifiers = self._resolve_identifier_kwargs( - entity, project, domain, entity.id.name, entity.id.version - ) return self._execute( entity, inputs, - project=resolved_identifiers.project, - domain=resolved_identifiers.domain, + project=project, + domain=domain, execution_name=execution_name, wait=wait, - labels=entity.labels if isinstance(entity, FlyteLaunchPlan) and entity.labels.values else None, - annotations=entity.annotations - if isinstance(entity, FlyteLaunchPlan) and entity.annotations.values - else None, - auth_role=entity.auth_role - if isinstance(entity, FlyteLaunchPlan) - and (entity.auth_role.assumable_iam_role or entity.auth_role.kubernetes_service_account) - else None, + options=options, ) @execute.register @@ -851,6 +719,9 @@ def _( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. 
@@ -859,17 +730,19 @@ def _( """ if name or version: remote_logger.warning(f"The 'name' and 'version' arguments are ignored for entities of type {type(entity)}") - resolved_identifiers = self._resolve_identifier_kwargs( - entity, project, domain, entity.id.name, entity.id.version - ) launch_plan = self.fetch_launch_plan(entity.id.project, entity.id.domain, entity.id.name, entity.id.version) return self.execute( launch_plan, inputs, - project=resolved_identifiers.project, - domain=resolved_identifiers.domain, + project=project, + domain=domain, execution_name=execution_name, + options=options, wait=wait, + from_domain=from_domain, + from_project=from_project, + version=version, + name=name, ) # Flytekit Entities @@ -885,15 +758,33 @@ def _( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: - """Execute an @task-decorated function or TaskTemplate task.""" + """ + Execute an @task-decorated function or TaskTemplate task. + TODO: We should not fetch the entity first, we should always register it. The version should be computed using + the hash of pickle? + """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) try: flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict) except Exception: - flyte_task: FlyteTask = self.register(entity, **resolved_identifiers_dict) + if issubclass(entity, PythonAutoContainerTask): + raise ValueError( + f"PythonTask {entity.name} not already registered. It cannot be auto-registered as the container" + f" image cannot be automatically deducted. Please register and then execute." 
+ ) + ss = SerializationSettings( + image_config=None, + project=project or self.default_project, + domain=domain or self._default_domain, + version=version, + ) + flyte_task: FlyteTask = self.register_task(entity, ss) flyte_task.guessed_python_interface = entity.python_interface return self.execute( flyte_task, @@ -914,38 +805,51 @@ def _( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: """Execute an @workflow-decorated function.""" - resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers = self._resolve_identifier_kwargs( + entity, project, domain, name, from_project, from_domain, version + ) resolved_identifiers_dict = asdict(resolved_identifiers) + ss = SerializationSettings( + image_config=None, + project=project or self.default_project, + domain=domain or self._default_domain, + version=version, + ) try: flyte_workflow: FlyteWorkflow = self.fetch_workflow(**resolved_identifiers_dict) except FlyteEntityNotExistException: remote_logger.info("Try to register FlyteWorkflow because it wasn't found in Flyte Admin!") - self._register_entity_if_not_exists(entity, resolved_identifiers_dict) - flyte_workflow: FlyteWorkflow = self.register(entity, **resolved_identifiers_dict) + flyte_workflow: FlyteWorkflow = self.register_workflow( + entity, ss, version=version, all_downstream=True, options=options + ) flyte_workflow.guessed_python_interface = entity.python_interface ctx = context_manager.FlyteContext.current_context() try: - self.fetch_launch_plan(**resolved_identifiers_dict) + flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict) except FlyteEntityNotExistException: remote_logger.info("Try to register default launch plan because it wasn't found in Flyte Admin!") default_lp = LaunchPlan.get_default_launch_plan(ctx, entity) - 
self.register(default_lp, **resolved_identifiers_dict) + self.register_launch_plan(default_lp, ss, version=version, options=options) + flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict) return self.execute( - flyte_workflow, + flyte_lp, inputs, - project=resolved_identifiers.project, - domain=resolved_identifiers.domain, + project=project, + domain=domain, execution_name=execution_name, wait=wait, + options=options, ) - @execute.register def _( self, entity: LaunchPlan, @@ -955,23 +859,38 @@ def _( name: str = None, version: str = None, execution_name: str = None, + from_project: str = None, + from_domain: str = None, + options: typing.Optional[Options] = None, wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a LaunchPlan object.""" - resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) - resolved_identifiers_dict = asdict(resolved_identifiers) + resolved_identifiers = self._resolve_identifier_kwargs( + entity, project, domain, name, from_project, from_domain, version + ) + dict_ids = asdict(resolved_identifiers) try: - flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**resolved_identifiers_dict) + flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**dict_ids) except Exception: - flyte_launchplan: FlyteLaunchPlan = self.register(entity, **resolved_identifiers_dict) + ss = SerializationSettings( + image_config=None, project=resolved_identifiers.project, domain=resolved_identifiers.domain + ) + flyte_launchplan: FlyteLaunchPlan = self.register_launch_plan( + entity, serialization_settings=ss, version=resolved_identifiers.version + ) flyte_launchplan.guessed_python_interface = entity.python_interface return self.execute( flyte_launchplan, inputs, - project=resolved_identifiers.project, - domain=resolved_identifiers.domain, + project=project, + domain=domain, execution_name=execution_name, + options=options, wait=wait, + from_domain=from_domain, + from_project=from_project, + 
version=version, + name=name, ) ################################### diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py new file mode 100644 index 0000000000..318a0e3fa0 --- /dev/null +++ b/flytekit/tools/serialize_helpers.py @@ -0,0 +1,120 @@ +import math +import os as _os +import sys +import typing +from collections import OrderedDict + +import click +from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan +from flyteidl.admin.task_pb2 import TaskSpec as _idl_admin_TaskSpec +from flyteidl.admin.workflow_pb2 import WorkflowSpec as _idl_admin_WorkflowSpec + +from flytekit import LaunchPlan +from flytekit.core import context_manager as flyte_context +from flytekit.core.base_task import PythonTask +from flytekit.core.workflow import WorkflowBase +from flytekit.exceptions.user import FlyteValidationException +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import task as task_models +from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.core import identifier as _identifier +from flytekit.tools.translator import get_serializable + + +def _determine_text_chars(length): + """ + This function is used to help prefix files. If there are only 10 entries, then we just need one digit (0-9) to be + the prefix. If there are 11, then we'll need two (00-10). + + :param int length: + :rtype: int + """ + if length == 0: + return 0 + return math.ceil(math.log(length, 10)) + + +def _should_register_with_admin(entity) -> bool: + """ + This is used in the code below. The translator.py module produces lots of objects (namely nodes and BranchNodes) + that do not/should not be written to .pb file to send to admin. This function filters them out. 
+ """ + return isinstance( + entity, (task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec) + ) + + +def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: + """ + Given a list of `TaskSpec`, this function returns a set containing the duplicated `TaskSpec` if any exists. + """ + seen: typing.Set[_identifier.Identifier] = set() + duplicate_tasks: typing.Set[task_models.TaskSpec] = set() + for task in tasks: + if task.template.id not in seen: + seen.add(task.template.id) + else: + duplicate_tasks.add(task) + return duplicate_tasks + + +def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List: + """ + Returns all entities that can be serialized and should be sent over to Flyte backend. This will filter any entities + that are not known to Admin + """ + new_api_serializable_entities = OrderedDict() + # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan + # object, which gets added to the FlyteEntities.entities list, which we're iterating over. + for entity in flyte_context.FlyteEntities.entities.copy(): + if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan): + get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity) + + if isinstance(entity, WorkflowBase): + lp = LaunchPlan.get_default_launch_plan(ctx, entity) + get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp) + + new_api_model_values = list(new_api_serializable_entities.values()) + entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) + serializable_tasks: typing.List[task_models.TaskSpec] = [ + entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) + ] + # Detect if any of the tasks is duplicated. 
Duplicate tasks are defined as having the same + # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate + # tasks are considered invalid at registration + # time and usually indicate user error, so we catch this common mistake at serialization time. + duplicate_tasks = _find_duplicate_tasks(serializable_tasks) + if len(duplicate_tasks) > 0: + duplicate_task_names = [task.template.id.name for task in duplicate_tasks] + raise FlyteValidationException( + f"Multiple definitions of the following tasks were found: {duplicate_task_names}" + ) + + return [v.to_flyte_idl() for v in entities_to_be_serialized] + + +def persist_registrable_entities(entities: typing.List, folder: str): + """ + For protobuf serializable list of entities, writes a file with the name if the entity and + enumeration order to the specified folder + """ + zero_padded_length = _determine_text_chars(len(entities)) + for i, entity in enumerate(entities): + name = "" + fname_index = str(i).zfill(zero_padded_length) + if isinstance(entity, _idl_admin_TaskSpec): + name = entity.template.id.name + fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name) + elif isinstance(entity, _idl_admin_WorkflowSpec): + name = entity.template.id.name + fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name) + elif isinstance(entity, _idl_admin_LaunchPlan): + name = entity.id.name + fname = "{}_{}_3.pb".format(fname_index, entity.id.name) + else: + click.secho(f"Entity is incorrect formatted {entity} - type {type(entity)}", fg="red") + sys.exit(-1) + click.secho(f" Packaging {name} -> {fname}", dim=True) + fname = _os.path.join(folder, fname) + with open(fname, "wb") as writer: + writer.write(entity.SerializeToString()) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 05b4224cc9..21668a236f 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,10 +1,12 @@ from collections import OrderedDict from typing 
import Callable, Dict, List, Optional, Tuple, Union +from flytekit import PythonFunctionTask +from flytekit.configuration import SerializationSettings +from flytekit.core import SERIALIZED_CONTEXT_ENV_VAR from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode -from flytekit.core.context_manager import SerializationSettings from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.node import Node from flytekit.core.python_auto_container import PythonAutoContainerTask @@ -103,13 +105,17 @@ def get_serializable_task( # tasks that rely on user code defined in the container. This should be encapsulated by the auto container # parent class entity.set_command_fn(_fast_serialize_command_fn(settings, entity)) + container = entity.get_container(settings) + if container and isinstance(entity, PythonFunctionTask): + if entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: + container.add_env(key=SERIALIZED_CONTEXT_ENV_VAR, val=settings.prepare_for_transport()) tt = task_models.TaskTemplate( id=task_id, type=entity.task_type, metadata=entity.metadata.to_taskmetadata_model(), interface=entity.interface, custom=entity.get_custom(settings), - container=entity.get_container(settings), + container=container, task_type_version=entity.task_type_version, security_context=entity.security_context, config=entity.get_config(settings), diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index fa0245236c..5c77eca972 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -5,7 +5,6 @@ import pandas from flytekit import FlyteContext -from flytekit.configuration import sdk from flytekit.core.type_engine import T, TypeEngine, TypeTransformer from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -56,40 
+55,10 @@ def write( ) -class FastParquetIO(ParquetIO): - PARQUET_ENGINE = "fastparquet" - - def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: - from fastparquet import ParquetFile as _ParquetFile - from fastparquet import thrift_structures as _ts - - # TODO Follow up to figure out if this is not needed anymore - # https://github.com/dask/fastparquet/issues/414#issuecomment-478983811 - df = pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, index=False) - df_column_types = df.dtypes - pf = _ParquetFile(chunk) - schema_column_dtypes = {l.name: l.type for l in list(pf.schema.schema_elements)} - - for idx in df_column_types[df_column_types == "float16"].index.tolist(): - # A hacky way to get the string representations of the column types of a parquet schema - # Reference: - # https://github.com/dask/fastparquet/blob/f4ecc67f50e7bf98b2d0099c9589c615ea4b06aa/fastparquet/schema.py - if _ts.parquet_thrift.Type._VALUES_TO_NAMES[schema_column_dtypes[idx]] == "BOOLEAN": - df[idx] = df[idx].astype("object") - df[idx].replace({0: False, 1: True, pandas.np.nan: None}, inplace=True) - return df - - -_PARQUETIO_ENGINES: typing.Dict[str, ParquetIO] = { - ParquetIO.PARQUET_ENGINE: ParquetIO(), - FastParquetIO.PARQUET_ENGINE: FastParquetIO(), -} - - class PandasSchemaReader(LocalIOSchemaReader[pandas.DataFrame]): def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) - self._parquet_engine = _PARQUETIO_ENGINES[sdk.PARQUET_ENGINE.get()] + self._parquet_engine = ParquetIO() def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: return self._parquet_engine.read(*path, columns=self.column_names, **kwargs) @@ -98,7 +67,7 @@ def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: class PandasSchemaWriter(LocalIOSchemaWriter[pandas.DataFrame]): def __init__(self, local_dir: os.PathLike, cols: 
typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) - self._parquet_engine = _PARQUETIO_ENGINES[sdk.PARQUET_ENGINE.get()] + self._parquet_engine = ParquetIO() def _write(self, df: T, path: os.PathLike, **kwargs): return self._parquet_engine.write(df, to_file=path, **kwargs) @@ -111,9 +80,7 @@ class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]): def __init__(self): super().__init__("PandasDataFrame<->GenericSchema", pandas.DataFrame) - self._parquet_engine = _PARQUETIO_ENGINES[sdk.PARQUET_ENGINE.get()] - # Pandas dataframes can have their hashes overriden to facilitate the case of caching pandas dataframes by - # value. + self._parquet_engine = ParquetIO() self._hash_overridable = True @staticmethod diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index c4fd015c57..3a4f9c01d7 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -1,7 +1,7 @@ -from flytekit.configuration.sdk import USE_STRUCTURED_DATASET +from flytekit.configuration.internal import LocalSDK from flytekit.loggers import logger -if USE_STRUCTURED_DATASET.get(): +if LocalSDK.USE_STRUCTURED_DATASET.read(): from .basic_dfs import ( ArrowToParquetEncodingHandler, PandasToParquetEncodingHandler, diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 0b36a34302..463be8642c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -22,7 +22,7 @@ import numpy as _np import pyarrow as pa -from flytekit.configuration.sdk import USE_STRUCTURED_DATASET +from flytekit.configuration.internal import LocalSDK from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.loggers import logger @@ -391,7 +391,7 @@ def register(cls, h: Handlers, default_for_type: 
Optional[bool] = True, override The string "://" should not be present in any handler's protocol so we don't check for it. """ - if not USE_STRUCTURED_DATASET.get(): + if not LocalSDK.USE_STRUCTURED_DATASET.read(): logger.info(f"Structured datasets not enabled, not registering handler {h}") return @@ -724,7 +724,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError(f"StructuredDatasetTransformerEngine cannot reverse {literal_type}") -if USE_STRUCTURED_DATASET.get(): +if LocalSDK.USE_STRUCTURED_DATASET.read(): logger.debug("Structured dataset module load... using structured datasets!") flyte_dataset_transformer = StructuredDatasetTransformerEngine() TypeEngine.register(flyte_dataset_transformer) diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py index 45ece39106..1ae47339b3 100644 --- a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py +++ b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py @@ -3,7 +3,8 @@ from google.protobuf.json_format import MessageToDict -from flytekit.extend import SerializationSettings, SQLTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import SQLTask from flytekit.models.presto import PrestoQuery from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-aws-athena/tests/test_athena.py b/plugins/flytekit-aws-athena/tests/test_athena.py index 9fd8c60762..4489e59e7d 100644 --- a/plugins/flytekit-aws-athena/tests/test_athena.py +++ b/plugins/flytekit-aws-athena/tests/test_athena.py @@ -4,7 +4,8 @@ from flytekitplugins.athena import AthenaConfig, AthenaTask from flytekit import kwtypes, workflow -from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable +from flytekit.configuration import Image, ImageConfig, 
SerializationSettings +from flytekit.extend import get_serializable from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index c9b30b7af1..8787e70011 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -6,7 +6,8 @@ from google.protobuf.struct_pb2 import Struct from flytekit import PythonFunctionTask -from flytekit.extend import SerializationSettings, TaskPlugins +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins @dataclass_json diff --git a/plugins/flytekit-aws-batch/tests/test_aws_batch.py b/plugins/flytekit-aws-batch/tests/test_aws_batch.py index 5cd1e5e5c2..73eada2b09 100644 --- a/plugins/flytekit-aws-batch/tests/test_aws_batch.py +++ b/plugins/flytekit-aws-batch/tests/test_aws_batch.py @@ -1,7 +1,7 @@ from flytekitplugins.awsbatch import AWSBatchConfig from flytekit import PythonFunctionTask, task -from flytekit.extend import Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings config = AWSBatchConfig( parameters={"codec": "mp4"}, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py index ae805d0e81..f180d5968c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py @@ -9,7 +9,8 @@ from google.protobuf.json_format import MessageToDict from flytekit import FlyteContext -from flytekit.extend import DictTransformer, PythonTask, SerializationSettings, TypeEngine, TypeTransformer +from flytekit.configuration import SerializationSettings +from flytekit.extend import DictTransformer, PythonTask, TypeEngine, TypeTransformer from flytekit.models.literals 
import Literal from flytekit.models.types import LiteralType, SimpleType diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py index 77d5b40781..7e6c0726ce 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py @@ -7,7 +7,8 @@ import flytekit from flytekit import ExecutionParameters, FlyteContextManager, PythonFunctionTask, kwtypes -from flytekit.extend import ExecutionState, IgnoreOutputs, Interface, PythonTask, SerializationSettings, TaskPlugins +from flytekit.configuration import SerializationSettings +from flytekit.extend import ExecutionState, IgnoreOutputs, Interface, PythonTask, TaskPlugins from flytekit.loggers import logger from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FlyteFile diff --git a/plugins/flytekit-aws-sagemaker/tests/test_training.py b/plugins/flytekit-aws-sagemaker/tests/test_training.py index a48d3c9f39..4d33a9e4bb 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_training.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_training.py @@ -14,8 +14,8 @@ import flytekit from flytekit import task +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.context_manager import ExecutionParameters -from flytekit.extend import Image, ImageConfig, SerializationSettings def _get_reg_settings(): diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index b7a5104dea..1d4a7f0dbd 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -6,7 +6,8 @@ from google.protobuf.struct_pb2 import Struct from flytekit import StructuredDataset -from flytekit.extend import SerializationSettings, SQLTask +from 
flytekit.configuration import SerializationSettings +from flytekit.extend import SQLTask from flytekit.models import task as _task_model diff --git a/plugins/flytekit-bigquery/tests/test_bigquery.py b/plugins/flytekit-bigquery/tests/test_bigquery.py index 78d6c0893f..7f4837ae0d 100644 --- a/plugins/flytekit-bigquery/tests/test_bigquery.py +++ b/plugins/flytekit-bigquery/tests/test_bigquery.py @@ -7,7 +7,8 @@ from google.protobuf.struct_pb2 import Struct from flytekit import StructuredDataset, kwtypes, workflow -from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable query_template = "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions` WHERE @version = 1 LIMIT 10" diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index 4ec7c6e6ca..b32e01650b 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -1,11 +1,12 @@ import importlib -from flytekit import USE_STRUCTURED_DATASET, StructuredDatasetTransformerEngine, logger +from flytekit import StructuredDatasetTransformerEngine, logger +from flytekit.configuration import internal from flytekit.types.structured.structured_dataset import GCS, S3 from .persist import FSSpecPersistence -if USE_STRUCTURED_DATASET.get(): +if internal.LocalSDK.USE_STRUCTURED_DATASET.read(): from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py index 8583f47a91..f9587a7562 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ 
b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py @@ -5,22 +5,32 @@ from fsspec.core import split_protocol from fsspec.registry import known_implementations -from flytekit.configuration import aws as _aws_config +from flytekit.configuration import DataConfig, S3Config, internal from flytekit.extend import DataPersistence, DataPersistencePlugins from flytekit.loggers import logger +S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" +S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" -def s3_setup_args(): +# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 +# for key and secret +_FSSPEC_S3_KEY_ID = "key" +_FSSPEC_S3_SECRET = "secret" + + +def s3_setup_args(s3_cfg: S3Config): kwargs = {} - if _aws_config.S3_ACCESS_KEY_ID.get() is not None: - os.environ[_aws_config.S3_ACCESS_KEY_ID_ENV_NAME] = _aws_config.S3_ACCESS_KEY_ID.get() + if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: + if s3_cfg.access_key_id: + kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - if _aws_config.S3_SECRET_ACCESS_KEY.get() is not None: - os.environ[_aws_config.S3_SECRET_ACCESS_KEY_ENV_NAME] = _aws_config.S3_SECRET_ACCESS_KEY.get() + if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ: + if s3_cfg.secret_access_key: + kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key # S3fs takes this as a special arg - if _aws_config.S3_ENDPOINT.get() is not None: - kwargs["client_kwargs"] = {"endpoint_url": _aws_config.S3_ENDPOINT.get()} + if s3_cfg.endpoint is not None: + kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} return kwargs @@ -33,9 +43,10 @@ class FSSpecPersistence(DataPersistence): method """ - def __init__(self, default_prefix=None): + def __init__(self, default_prefix=None, data_config: typing.Optional[DataConfig] = None): super(FSSpecPersistence, self).__init__(name="fsspec-persistence", default_prefix=default_prefix) self.default_protocol = self.get_protocol(default_prefix) + self._data_cfg = data_config if 
data_config else DataConfig.auto() @staticmethod def get_protocol(path: typing.Optional[str] = None): @@ -48,14 +59,13 @@ def get_protocol(path: typing.Optional[str] = None): protocol = "file" return protocol - @staticmethod - def get_filesystem(path: str) -> fsspec.AbstractFileSystem: + def get_filesystem(self, path: str) -> fsspec.AbstractFileSystem: protocol = FSSpecPersistence.get_protocol(path) kwargs = {} if protocol == "file": kwargs = {"auto_mkdir": True} elif protocol == "s3": - kwargs = s3_setup_args() + kwargs = s3_setup_args(self._data_cfg.s3) return fsspec.filesystem(protocol, **kwargs) # type: ignore @staticmethod diff --git a/plugins/flytekit-hive/flytekitplugins/hive/task.py b/plugins/flytekit-hive/flytekitplugins/hive/task.py index 3280f4fb7a..2781275c86 100644 --- a/plugins/flytekit-hive/flytekitplugins/hive/task.py +++ b/plugins/flytekit-hive/flytekitplugins/hive/task.py @@ -3,7 +3,8 @@ from google.protobuf.json_format import MessageToDict -from flytekit.extend import SerializationSettings, SQLTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import SQLTask from flytekit.models.qubole import HiveQuery, QuboleHiveJob from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-hive/tests/test_hive_task.py b/plugins/flytekit-hive/tests/test_hive_task.py index 372fe5a868..137dd15dd0 100644 --- a/plugins/flytekit-hive/tests/test_hive_task.py +++ b/plugins/flytekit-hive/tests/test_hive_task.py @@ -5,7 +5,8 @@ from flytekitplugins.hive.task import HiveConfig, HiveSelectTask, HiveTask from flytekit import kwtypes, workflow -from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable from flytekit.testing import task_mock from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py 
b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index e9d7c93d95..7d16ba0b14 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -5,8 +5,9 @@ from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements from flytekit import FlyteContext, PythonFunctionTask +from flytekit.configuration import SerializationSettings from flytekit.exceptions import user as _user_exceptions -from flytekit.extend import Promise, SerializationSettings, TaskPlugins +from flytekit.extend import Promise, TaskPlugins from flytekit.loggers import logger from flytekit.models import task as _task_models diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 85fd2d8041..e9814c98d3 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -8,10 +8,10 @@ from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount from flytekit import Resources, TaskMetadata, dynamic, map_task, task +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager -from flytekit.core.context_manager import FastSerializationSettings from flytekit.core.type_engine import TypeEngine -from flytekit.extend import ExecutionState, Image, ImageConfig, SerializationSettings +from flytekit.extend import ExecutionState from flytekit.tools.translator import get_serializable @@ -423,7 +423,11 @@ def dynamic_task_with_pod_subtask(dummy_input: str) -> str: version="version", env={"FOO": "baz"}, image_config=ImageConfig(default_image=default_img, images=[default_img]), - fast_serialization_settings=FastSerializationSettings(enabled=True), + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + 
distribution_location="s3://my-s3-bucket/fast/123", + ), ) with context_manager.FlyteContextManager.with_context( @@ -433,10 +437,6 @@ def dynamic_task_with_pod_subtask(dummy_input: str) -> str: ctx.with_execution_state( ctx.execution_state.with_params( mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, ) ) ) as ctx: diff --git a/plugins/flytekit-kf-mpi/__init__.py b/plugins/flytekit-kf-mpi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index acdec6e175..6f207b421d 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -9,7 +9,8 @@ from google.protobuf.json_format import MessageToDict from flytekit import PythonFunctionTask -from flytekit.extend import SerializationSettings, TaskPlugins +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins from flytekit.models import common as _common diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 014237fc0f..24dc441c63 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,8 +1,7 @@ from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel from flytekit import Resources, task -from flytekit.core.context_manager import EntrypointSettings -from flytekit.extend import Image, ImageConfig, SerializationSettings +from flytekit.configuration import EntrypointSettings, Image, ImageConfig, SerializationSettings def test_mpi_model_task(): @@ -38,7 +37,6 @@ def my_mpi_task(x: int, y: str) -> int: version="version", env={"FOO": "baz"}, image_config=ImageConfig(default_image=default_img, images=[default_img]), - 
entrypoint_settings=EntrypointSettings(path="/etc/my-entrypoint", command="my-entrypoint"), ) assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index c72f615307..4b0bde78b0 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -8,7 +8,8 @@ from google.protobuf.json_format import MessageToDict from flytekit import PythonFunctionTask -from flytekit.extend import SerializationSettings, TaskPlugins +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins from .models import PyTorchJob diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 6d5db78bfd..00eb6c0953 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -1,7 +1,7 @@ from flytekitplugins.kfpytorch.task import PyTorch from flytekit import Resources, task -from flytekit.extend import Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings def test_pytorch_task(): diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index f8b767a631..03855e3095 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -8,7 +8,8 @@ from google.protobuf.json_format import MessageToDict from flytekit import PythonFunctionTask -from flytekit.extend import SerializationSettings, TaskPlugins +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins from .models import 
TensorFlowJob diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 2bdaf747d5..2bcfcda550 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -1,7 +1,7 @@ from flytekitplugins.kftensorflow import TfJob from flytekit import Resources, task -from flytekit.extend import Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings def test_tensorflow_task(): diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 7652fcfb1f..534acb978e 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing import Dict, Optional, Type -from flytekit.extend import SerializationSettings, SQLTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import SQLTask from flytekit.models import task as _task_model from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index a63e5c195e..ab558ca534 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -4,7 +4,8 @@ from flytekitplugins.snowflake import SnowflakeConfig, SnowflakeTask from flytekit import kwtypes, workflow -from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable from flytekit.types.schema import FlyteSchema query_template = """ diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py 
b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py index d239632248..d309fc71f8 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py @@ -1,7 +1,7 @@ -from flytekit.configuration.sdk import USE_STRUCTURED_DATASET +from flytekit.configuration import internal as _internal from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa -from .task import Spark, new_spark_session +from .task import Spark, new_spark_session # noqa -if USE_STRUCTURED_DATASET.get(): +if _internal.LocalSDK.USE_STRUCTURED_DATASET.read(): from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index b59fb3aed5..8428b492ce 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -7,8 +7,9 @@ from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask +from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ExecutionParameters -from flytekit.extend import ExecutionState, SerializationSettings, TaskPlugins +from flytekit.extend import ExecutionState, TaskPlugins from .models import SparkJob, SparkType @@ -94,7 +95,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job = SparkJob( spark_conf=self.task_config.spark_conf, hadoop_conf=self.task_config.hadoop_conf, - application_file="local://" + settings.entrypoint_settings.path if settings.entrypoint_settings else "", + application_file="local://" + settings.entrypoint_settings.path, executor_path=settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, diff --git a/plugins/flytekit-spark/tests/test_remote_register.py b/plugins/flytekit-spark/tests/test_remote_register.py 
index 67d0f63b1f..41798e9ecd 100644 --- a/plugins/flytekit-spark/tests/test_remote_register.py +++ b/plugins/flytekit-spark/tests/test_remote_register.py @@ -2,12 +2,11 @@ from mock import MagicMock, patch from flytekit import task +from flytekit.configuration import Config, SerializationSettings from flytekit.remote.remote import FlyteRemote -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_spark_template_with_remote(mock_insecure, mock_url): +def test_spark_template_with_remote(): @task(task_config=Spark(spark_conf={"spark": "1"})) def my_spark(a: str) -> int: return 10 @@ -16,24 +15,30 @@ def my_spark(a: str) -> int: def my_python_task(a: str) -> int: return 10 - mock_url.get.return_value = "localhost" + remote = FlyteRemote( + config=Config.for_endpoint(endpoint="localhost", insecure=True), default_project="p1", default_domain="d1" + ) - mock_insecure.get.return_value = True mock_client = MagicMock() - - remote = FlyteRemote.from_config("p1", "d1") - - remote._image_config = MagicMock() remote._client = mock_client - remote.register(my_spark) + remote.register_task( + my_spark, + serialization_settings=SerializationSettings( + image_config=MagicMock(), + ), + version="v1", + ) serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"] + print(serialized_spec) # Check if the serialized spark task has mainApplicaitonFile field set. assert serialized_spec.template.custom["mainApplicationFile"] assert serialized_spec.template.custom["sparkConf"] - remote.register(my_python_task) + remote.register_task( + my_python_task, serialization_settings=SerializationSettings(image_config=MagicMock()), version="v1" + ) serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"] # Check if the serialized python task has no mainApplicaitonFile field set by default. 
diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index dfda716a11..8d7b761126 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -3,8 +3,8 @@ import flytekit from flytekit import task +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.context_manager import ExecutionParameters -from flytekit.extend import Image, ImageConfig, SerializationSettings def test_spark_task(): diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index a04f143e58..42ef2160f8 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -5,8 +5,8 @@ from sqlalchemy import create_engine # type: ignore from flytekit import current_context, kwtypes +from flytekit.configuration import SerializationSettings from flytekit.core.base_sql_task import SQLTask -from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor from flytekit.models import task as task_models diff --git a/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py b/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py index 93104eabdd..79b59f9974 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py +++ b/plugins/flytekit-sqlalchemy/tests/test_sql_tracker.py @@ -1,7 +1,8 @@ from collections import OrderedDict +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.tools.translator import get_serializable from .test_task import tk as not_tk @@ -13,7 +14,7 @@ def test_sql_lhs(): def 
test_sql_command(): default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/pull_request_template.md b/pull_request_template.md index d2becf38b7..7b96078dcf 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -18,9 +18,9 @@ _Please replace this text with a description of what this PR accomplishes._ _How did you fix the bug, make the feature etc. Link to any design docs etc_ ## Tracking Issue -https://github.com/lyft/flyte/issues/ +https://github.com/flyteorg/flyte/issues/ ## Follow-up issue _NA_ OR -_https://github.com/lyft/flyte/issues/_ +_https://github.com/flyteorg/flyte/issues/_ diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 0ca4f26b1e..f5036e56b5 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -9,6 +9,7 @@ import pytest from flytekit import kwtypes +from flytekit.configuration import Config from flytekit.core.launch_plan import LaunchPlan from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task @@ -50,14 +51,14 @@ def test_client(flyteclient, flyte_workflows_register, docker_services): def test_fetch_execute_launch_plan(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}, wait=True) assert execution.outputs["o0"] == "hello world" def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): - remote = 
FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.basic_workflow.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 10, "b": "foobar"}, wait=True) assert execution.node_executions["n0"].inputs == {"a": 10} @@ -75,7 +76,7 @@ def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}) @@ -111,7 +112,8 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.subworkflows.parent_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 101}, wait=True) # check node execution inputs and outputs @@ -127,7 +129,8 @@ def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflow def test_fetch_execute_launch_plan_with_child_workflows(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.child_workflow.parent_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 3}, wait=True) @@ -141,7 +144,7 @@ def test_fetch_execute_launch_plan_with_child_workflows(flyteclient, flyte_workf def 
test_fetch_execute_workflow(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_workflow = remote.fetch_workflow(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_workflow, {}, wait=True) assert execution.outputs["o0"] == "hello world" @@ -153,7 +156,7 @@ def test_fetch_execute_workflow(flyteclient, flyte_workflows_register): def test_fetch_execute_task(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_task = remote.fetch_task(name="workflows.basic.basic_workflow.t1", version=f"v{VERSION}") execution = remote.execute(flyte_task, {"a": 10}, wait=True) assert execution.outputs["t1_int_output"] == 12 @@ -169,7 +172,7 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote # make sure the task name is the same as the name used during registration t1._name = t1.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") execution = remote.execute(t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" @@ -182,7 +185,7 @@ def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_re # make sure the task name is the same as the name used during registration my_wf._name = my_wf.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=f"v{VERSION}", wait=True) assert execution.outputs["o0"] == 12 assert execution.outputs["o1"] == "xyzworld" @@ -198,7 +201,7 @@ def 
test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_re def test_fetch_execute_launch_plan_list_of_floats(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.list_float_wf.my_wf", version=f"v{VERSION}") xs: typing.List[float] = [42.24, 999.1, 0.0001] execution = remote.execute(flyte_launch_plan, inputs={"xs": xs}, wait=True) @@ -206,7 +209,7 @@ def test_fetch_execute_launch_plan_list_of_floats(flyteclient, flyte_workflows_r def test_fetch_execute_task_list_of_floats(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_task = remote.fetch_task(name="workflows.basic.list_float_wf.concat_list", version=f"v{VERSION}") xs: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7] execution = remote.execute(flyte_task, {"xs": xs}, wait=True) @@ -214,7 +217,7 @@ def test_fetch_execute_task_list_of_floats(flyteclient, flyte_workflows_register def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_task = remote.fetch_task(name="workflows.basic.dict_str_wf.convert_to_string", version=f"v{VERSION}") d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} execution = remote.execute(flyte_task, {"d": d}, wait=True) @@ -228,7 +231,7 @@ def test_execute_python_workflow_dict_of_string_to_string(flyteclient, flyte_wor # make sure the task name is the same as the name used during registration my_wf._name = my_wf.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") d: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"} 
execution = remote.execute(my_wf, inputs={"d": d}, version=f"v{VERSION}", wait=True) assert json.loads(execution.outputs["o0"]) == {"k1": "v1", "k2": "v2"} @@ -246,7 +249,7 @@ def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_reg # make sure the task name is the same as the name used during registration my_wf._name = my_wf.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") xs: typing.List[float] = [42.24, 999.1, 0.0001] execution = remote.execute(my_wf, inputs={"xs": xs}, version=f"v{VERSION}", wait=True) @@ -258,7 +261,7 @@ def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_reg def test_execute_sqlite3_task(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") example_db = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip" interactive_sql_task = SQLite3Task( @@ -281,7 +284,7 @@ def test_execute_sqlite3_task(flyteclient, flyte_workflows_register, flyte_remot def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") flyte_workflow = remote.fetch_workflow(name="workflows.basic.joblib.joblib_workflow", version=f"v{VERSION}") input_obj = [1, 2, 3] execution = remote.execute(flyte_workflow, {"obj": input_obj}, wait=True) @@ -298,7 +301,7 @@ def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, # make sure the task name is the same as the name used during registration parent_wf._name = parent_wf.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") execution = remote.execute(parent_wf, {"a": 
101}, version=f"v{VERSION}", wait=True) # check node execution inputs and outputs assert execution.node_executions["n0"].inputs == {"a": 101} @@ -313,6 +316,6 @@ def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, def test_fetch_not_exist_launch_plan(flyteclient): - remote = FlyteRemote.from_config(PROJECT, "development") + remote = FlyteRemote(Config.auto(), PROJECT, "development") with pytest.raises(FlyteEntityNotExistException): remote.fetch_launch_plan(name="workflows.basic.list_float_wf.fake_wf", version=f"v{VERSION}") diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index ced8c716c5..268c838c51 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -6,51 +6,13 @@ from flyteidl.admin.workflow_pb2 import WorkflowSpec import flytekit -from flytekit.clis.sdk_in_container import package, pyflyte, serialize +import flytekit.configuration +import flytekit.tools.serialize_helpers +from flytekit.clis.sdk_in_container import package, pyflyte from flytekit.core import context_manager from flytekit.exceptions.user import FlyteValidationException -def test_validate_image(): - ic = package.validate_image(None, "image", ()) - assert ic - assert ic.default_image is None - - img1 = "xyz:latest" - img2 = "docker.io/xyz:latest" - img3 = "docker.io/xyz:latest" - img3_cli = f"default={img3}" - img4 = "docker.io/my:azb" - img4_cli = f"my_img={img4}" - - ic = package.validate_image(None, "image", (img1,)) - assert ic - assert ic.default_image.full == img1 - - ic = package.validate_image(None, "image", (img2,)) - assert ic - assert ic.default_image.full == img2 - - ic = package.validate_image(None, "image", (img3_cli,)) - assert ic - assert ic.default_image.full == img3 - - with pytest.raises(click.BadParameter): - package.validate_image(None, "image", (img1, img3_cli)) - - with pytest.raises(click.BadParameter): - 
package.validate_image(None, "image", (img1, img2)) - - with pytest.raises(click.BadParameter): - package.validate_image(None, "image", (img1, img1)) - - ic = package.validate_image(None, "image", (img3_cli, img4_cli)) - assert ic - assert ic.default_image.full == img3 - assert len(ic.images) == 1 - assert ic.images[0].full == img4 - - @flytekit.task def foo(): pass @@ -63,17 +25,17 @@ def wf(): def test_get_registrable_entities(): ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="p", domain="d", version="v", - image_config=context_manager.ImageConfig( - default_image=context_manager.Image("def", "docker.io/def", "latest") + image_config=flytekit.configuration.ImageConfig( + default_image=flytekit.configuration.Image("def", "docker.io/def", "latest") ), ) ) context_manager.FlyteEntities.entities = [foo, wf, "str"] - entities = serialize.get_registrable_entities(ctx) + entities = flytekit.tools.serialize_helpers.get_registrable_entities(ctx) assert entities assert len(entities) == 3 @@ -114,12 +76,12 @@ def wf_2(): return t_1() ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="p", domain="d", version="v", - image_config=context_manager.ImageConfig( - default_image=context_manager.Image("def", "docker.io/def", "latest") + image_config=flytekit.configuration.ImageConfig( + default_image=flytekit.configuration.Image("def", "docker.io/def", "latest") ), ) ) @@ -130,7 +92,7 @@ def wf_2(): FlyteValidationException, match=r"Multiple definitions of the following tasks were found: \['pyflyte.test_package.t_1'\]", ): - serialize.get_registrable_entities(ctx) + flytekit.tools.serialize_helpers.get_registrable_entities(ctx) def test_package(): diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py 
b/tests/flytekit/unit/cli/test_flyte_cli.py index 7c5830b728..e3e4dd9bb3 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -69,40 +69,3 @@ def test_activate_project(mock_client): result = runner.invoke(_main._flyte_cli, ["activate-project", "-p", "foo", "-h", "a.b.com", "-i"]) assert result.exit_code == 0 mock_client().update_project.assert_called_with(_Project.active_project("foo")) - - -@_responses.activate -def test_setup_config_secure_mode(): - runner = _CliRunner() - data = { - "client_id": "123abc123", - "redirect_uri": "http://localhost:53593/callback", - "scopes": ["scope_1", "scope_2"], - "authorization_metadata_key": "fake_key", - } - _responses.add(_responses.GET, "https://flyte.company.com/config/v1/flyte_client", json=data, status=200) - with _mock.patch("configparser.ConfigParser.write"): - result = runner.invoke(_main._flyte_cli, ["setup-config", "-h", "flyte.company.com"]) - assert result.exit_code == 0 - - -@_responses.activate -def test_setup_config_insecure_mode(): - runner = _CliRunner() - - _responses.add(_responses.GET, "http://flyte.company.com/config/v1/flyte_client", json={}, status=200) - with _mock.patch("configparser.ConfigParser.write"): - result = runner.invoke(_main._flyte_cli, ["setup-config", "-h", "flyte.company.com", "-i"]) - assert result.exit_code == 0 - - -def test_flyte_cli(): - runner = _CliRunner() - result = runner.invoke(_main._flyte_cli, ["-c", "~/.flyte/config", "activate-project", "-i"]) - assert "Config file not found at ~/.flyte/config" in result.output - with _mock.patch("os.path.exists") as mock_exists: - result = runner.invoke(_main._flyte_cli, ["activate-project", "-p", "foo", "-i"]) - assert "Using default config file at" in result.output - mock_exists.return_value = True - result = runner.invoke(_main._flyte_cli, ["-c", "~/.flyte/config", "activate-project", "-i"]) - assert "Using config file at ~/.flyte/config" in result.output diff --git 
a/tests/flytekit/unit/clients/test_friendly.py b/tests/flytekit/unit/clients/test_friendly.py index e2e147dc1d..d45320cef0 100644 --- a/tests/flytekit/unit/clients/test_friendly.py +++ b/tests/flytekit/unit/clients/test_friendly.py @@ -2,12 +2,13 @@ from flyteidl.admin import project_pb2 as _project_pb2 from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient +from flytekit.configuration import PlatformConfig from flytekit.models.project import Project as _Project @_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project") def test_update_project(mock_raw_update_project): - client = _SynchronousFlyteClient(url="a.b.com", insecure=True) + client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) project = _Project("foo", "name", "description", state=_Project.ProjectState.ACTIVE) client.update_project(project) mock_raw_update_project.assert_called_with(project.to_flyte_idl()) @@ -15,7 +16,7 @@ def test_update_project(mock_raw_update_project): @_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.list_projects") def test_list_projects_paginated(mock_raw_list_projects): - client = _SynchronousFlyteClient(url="a.b.com", insecure=True) + client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) client.list_projects_paginated(limit=100, token="") project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None) mock_raw_list_projects.assert_called_with(project_list_request=project_list_request) diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 86f5a04a6b..646b5bea85 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -1,5 +1,4 @@ import json -import os from subprocess import CompletedProcess import mock @@ -8,17 +7,8 @@ from flyteidl.service import auth_pb2 from mock import MagicMock, patch -from flytekit.clients.raw import 
RawSynchronousFlyteClient as _RawSynchronousFlyteClient -from flytekit.clients.raw import ( - _get_refresh_handler, - _refresh_credentials_basic, - _refresh_credentials_from_command, - _refresh_credentials_standard, - get_basic_authorization_header, - get_secret, - get_token, -) -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET +from flytekit.clients.raw import RawSynchronousFlyteClient, get_basic_authorization_header, get_token +from flytekit.configuration import AuthType, PlatformConfig def get_admin_stub_mock() -> mock.MagicMock: @@ -45,43 +35,38 @@ def get_admin_stub_mock() -> mock.MagicMock: @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw._insecure_channel") -@mock.patch("flytekit.clients.raw._secure_channel") +@mock.patch("flytekit.clients.raw.grpc.insecure_channel") +@mock.patch("flytekit.clients.raw.grpc.secure_channel") def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() - client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) client.set_access_token("abc") assert client._metadata[0][1] == "Bearer abc" assert client.check_access_token("abc") -@mock.patch("flytekit.configuration.creds.COMMAND.get") @mock.patch("subprocess.run") -def test_refresh_credentials_from_command(mock_call_to_external_process, mock_command_from_config): +def test_refresh_credentials_from_command(mock_call_to_external_process): command = ["command", "generating", "token"] token = "token" - mock_command_from_config.return_value = command mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) - 
mock_client = mock.MagicMock() - _refresh_credentials_from_command(mock_client) + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=command)) + cc._refresh_credentials_from_command() mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True) - mock_client.set_access_token.assert_called_with(token) -@mock.patch("flytekit.configuration.creds.SCOPES.get") -@mock.patch("flytekit.clients.raw.get_secret") @mock.patch("flytekit.clients.raw.get_basic_authorization_header") @mock.patch("flytekit.clients.raw.get_token") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw._insecure_channel") -@mock.patch("flytekit.clients.raw._secure_channel") +@mock.patch("flytekit.clients.raw.grpc.insecure_channel") +@mock.patch("flytekit.clients.raw.grpc.secure_channel") def test_refresh_client_credentials_aka_basic( mock_secure_channel, mock_channel, @@ -89,11 +74,7 @@ def test_refresh_client_credentials_aka_basic( mock_admin_auth, mock_get_token, mock_get_basic_header, - mock_secret, - mock_scopes, ): - mock_secret.return_value = "sosecret" - mock_scopes.return_value = ["a", "b", "c", "d"] mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True @@ -101,10 +82,14 @@ def test_refresh_client_credentials_aka_basic( mock_get_token.return_value = ("token1", 1234567) mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() - client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + client = RawSynchronousFlyteClient( + PlatformConfig( + endpoint="a.b.com", insecure=True, client_credentials_secret="sosecret", scopes=["a", "b", "c", "d"] + ) + ) client._metadata = None assert not client.check_access_token("fdsa") - _refresh_credentials_basic(client) + client._refresh_credentials_basic() # Scopes from configuration take precendence. 
mock_get_token.assert_called_once_with("https://your.domain.io/oauth2/token", "Basic 123", "a,b,c,d") @@ -113,41 +98,52 @@ def test_refresh_client_credentials_aka_basic( assert client._metadata[0][0] == "authorization" -def test_raises(): - mm = MagicMock() - mm.public_client_config = None - with pytest.raises(ValueError): - _refresh_credentials_basic(mm) +@mock.patch("flytekit.clients.raw.auth_service") +@mock.patch("flytekit.clients.raw._admin_service") +@mock.patch("flytekit.clients.raw.grpc.insecure_channel") +@mock.patch("flytekit.clients.raw.grpc.secure_channel") +def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth): + mock_secure_channel.return_value = True + mock_channel.return_value = True + mock_admin.AdminServiceStub.return_value = True - mm = MagicMock() - mm.oauth2_metadata = None + # If the public client config is missing then raise an error + mocked_auth = get_admin_stub_mock() + mocked_auth.GetPublicClientConfig.return_value = None + mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth + client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) + assert client.public_client_config is None + with pytest.raises(ValueError): + client._refresh_credentials_basic() + + # If the oauth2 metadata is missing then raise an error + mocked_auth = get_admin_stub_mock() + mocked_auth.GetOAuth2Metadata.return_value = None + mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth + client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) + assert client.oauth2_metadata is None with pytest.raises(ValueError): - _refresh_credentials_basic(mm) + client._refresh_credentials_basic() @mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw._insecure_channel") +@mock.patch("flytekit.clients.raw.grpc.insecure_channel") def test_update_project(mock_channel, mock_admin): - client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + 
client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) project = _project_pb2.Project(id="foo", name="name", description="description", state=_project_pb2.Project.ACTIVE) client.update_project(project) mock_admin.AdminServiceStub().UpdateProject.assert_called_with(project, metadata=None) @mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw._insecure_channel") +@mock.patch("flytekit.clients.raw.grpc.insecure_channel") def test_list_projects_paginated(mock_channel, mock_admin): - client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None) client.list_projects(project_list_request) mock_admin.AdminServiceStub().ListProjects.assert_called_with(project_list_request, metadata=None) -def test_get_secret(): - os.environ[_CREDENTIALS_SECRET.env_var] = "abc" - assert get_secret() == "abc" - - def test_get_basic_authorization_header(): header = get_basic_authorization_header("client_id", "abc") assert header == "Basic Y2xpZW50X2lkOmFiYw==" @@ -164,10 +160,26 @@ def test_get_token(mock_requests): assert expiration == 60 -def test_get_refresh_handler(): - cc = _get_refresh_handler("client_credentials") - basic = _get_refresh_handler("basic") - assert basic is cc - assert basic is _refresh_credentials_basic - standard = _get_refresh_handler("standard") - assert standard is _refresh_credentials_standard +@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_standard") +def test_refresh_standard(mocked_method): + cc = RawSynchronousFlyteClient(PlatformConfig()) + cc.refresh_credentials() + assert mocked_method.called + + +@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_basic") +def test_refresh_basic(mocked_method): + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.BASIC)) + 
cc.refresh_credentials() + assert mocked_method.called + + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) + cc.refresh_credentials() + assert mocked_method.called + + +@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") +def test_refresh_from_command(mocked_method): + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS)) + cc.refresh_credentials() + assert mocked_method.called diff --git a/tests/flytekit/unit/configuration/configs/good.config b/tests/flytekit/unit/configuration/configs/good.config index 431eb0ed1f..70d14b6afa 100644 --- a/tests/flytekit/unit/configuration/configs/good.config +++ b/tests/flytekit/unit/configuration/configs/good.config @@ -14,6 +14,8 @@ url=fakeflyte.com int_value=3 string_value=abc +bool_value=false +timedelta_value=20h [resources] default_cpu_request=500m diff --git a/tests/flytekit/unit/configuration/conftest.py b/tests/flytekit/unit/configuration/conftest.py deleted file mode 100644 index 700805ee96..0000000000 --- a/tests/flytekit/unit/configuration/conftest.py +++ /dev/null @@ -1,14 +0,0 @@ -import os as _os - -import pytest as _pytest - -from flytekit.configuration import set_flyte_config_file as _set_config - - -@_pytest.fixture(scope="function", autouse=True) -def clear_configs(): - _set_config(None) - environment_variables = _os.environ.copy() - yield - _os.environ = environment_variables - _set_config(None) diff --git a/tests/flytekit/unit/configuration/test_common.py b/tests/flytekit/unit/configuration/test_common.py deleted file mode 100644 index a9088bdba6..0000000000 --- a/tests/flytekit/unit/configuration/test_common.py +++ /dev/null @@ -1,34 +0,0 @@ -import os - -import pytest - -from flytekit.configuration import common, set_flyte_config_file - - -def test_file_loader_bad(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/bad.config")) - with pytest.raises(Exception): - 
common.CONFIGURATION_SINGLETON.get_string("a", "b") - - -def test_file_loader_good(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) - assert common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" - assert common.CONFIGURATION_SINGLETON.get_string("auth", "assumable_iam_role") == "some_role" - - -def test_env_var_precedence_string(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) - - assert common.FlyteIntegerConfigurationEntry("madeup", "int_value").get() == 3 - assert common.FlyteStringConfigurationEntry("madeup", "string_value").get() == "abc" - - old_environ = dict(os.environ) - try: - os.environ["FLYTE_MADEUP_INT_VALUE"] = "10" - os.environ["FLYTE_MADEUP_STRING_VALUE"] = "overridden" - assert common.FlyteIntegerConfigurationEntry("madeup", "int_value").get() == 10 - assert common.FlyteStringConfigurationEntry("madeup", "string_value").get() == "overridden" - finally: - os.environ.clear() - os.environ.update(old_environ) diff --git a/tests/flytekit/unit/configuration/test_file.py b/tests/flytekit/unit/configuration/test_file.py new file mode 100644 index 0000000000..c7261fff1b --- /dev/null +++ b/tests/flytekit/unit/configuration/test_file.py @@ -0,0 +1,118 @@ +import configparser +import datetime +import os + +import mock +import pytest +from pytimeparse.timeparse import timeparse + +from flytekit.configuration import ConfigEntry, get_config_file, set_if_exists +from flytekit.configuration.file import LegacyConfigEntry + + +def test_set_if_exists(): + d = {} + d = set_if_exists(d, "k", None) + assert len(d) == 0 + d = set_if_exists(d, "k", []) + assert len(d) == 0 + d = set_if_exists(d, "k", "x") + assert len(d) == 1 + assert d["k"] == "x" + + +def test_get_config_file(): + c = get_config_file(None) + assert c is None + c = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 
"configs/good.config")) + assert c is not None + assert c.legacy_config is not None + + with pytest.raises(configparser.Error): + get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/bad.config")) + + +def test_config_entry_envvar(): + # Pytest feature + c = ConfigEntry(LegacyConfigEntry("test", "op1", str)) + assert c.read() is None + + old_environ = dict(os.environ) + os.environ["FLYTE_TEST_OP1"] = "xyz" + assert c.read() == "xyz" + os.environ = old_environ + + +def test_config_entry_file(): + # Pytest feature + c = ConfigEntry(LegacyConfigEntry("platform", "url", str)) + assert c.read() is None + + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + assert c.read(cfg) == "fakeflyte.com" + + c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist + assert c.read(cfg) is None + + +def test_config_entry_precedence(): + # Pytest feature + c = ConfigEntry(LegacyConfigEntry("platform", "url", str)) + assert c.read() is None + + old_environ = dict(os.environ) + os.environ["FLYTE_PLATFORM_URL"] = "xyz" + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + assert c.read(cfg) == "xyz" + # reset + os.environ = old_environ + + +def test_config_entry_types(): + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + + l = ConfigEntry(LegacyConfigEntry("sdk", "workflow_packages", list)) + assert l.read(cfg) == ["this.module", "that.module"] + + s = ConfigEntry(LegacyConfigEntry("madeup", "string_value")) + assert s.read(cfg) == "abc" + + i = ConfigEntry(LegacyConfigEntry("madeup", "int_value", int)) + assert i.read(cfg) == 3 + + b = ConfigEntry(LegacyConfigEntry("madeup", "bool_value", bool)) + assert b.read(cfg) is False + + t = ConfigEntry( + LegacyConfigEntry("madeup", "timedelta_value", datetime.timedelta), + transform=lambda x: 
datetime.timedelta(seconds=timeparse(x)), + ) + assert t.read(cfg) == datetime.timedelta(hours=20) + + +@mock.patch("flytekit.configuration.file.LegacyConfigEntry.read_from_file") +def test_env_var_bool_transformer(mock_file_read): + mock_file_read.return_value = None + test_env_var = "FLYTE_MADEUP_TEST_VAR_ABC123" + b = ConfigEntry(LegacyConfigEntry("madeup", "test_var_abc123", bool)) + + os.environ[test_env_var] = "FALSE" + assert b.read() is False + + os.environ[test_env_var] = "" + assert b.read() is False + + os.environ[test_env_var] = "1" + assert b.read() + + os.environ[test_env_var] = "truee" + assert b.read() + # The above reads shouldn't have triggered the file read since the env var was set + assert mock_file_read.call_count == 0 + + del os.environ[test_env_var] + + assert b.read() is None + + # The last read should've triggered the file read since now the env var is no longer set. + assert mock_file_read.call_count == 1 diff --git a/tests/flytekit/unit/configuration/test_images.py b/tests/flytekit/unit/configuration/test_images.py deleted file mode 100644 index 7a475fd64a..0000000000 --- a/tests/flytekit/unit/configuration/test_images.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -from flytekit.configuration import images, set_flyte_config_file - - -def test_load_images(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")) - imgs = images.get_specified_images() - assert imgs == {"abc": "docker.io/abc", "xyz": "docker.io/xyz:latest"} - - -def test_no_images(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) - imgs = images.get_specified_images() - assert imgs == {} diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 73cbf62b49..accaa0fdd5 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -1,13 
+1,16 @@ -import pytest +import os -from flytekit.configuration.internal import look_up_version_from_image_tag +from flytekit.configuration import get_config_file +from flytekit.configuration.internal import Images -def test_parsing(): - str = "somedocker.com/myimage:someversion123" - version = look_up_version_from_image_tag(str) - assert version == "someversion123" +def test_load_images(): + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")) + imgs = Images.get_specified_images(cfg) + assert imgs == {"abc": "docker.io/abc", "xyz": "docker.io/xyz:latest"} - str = "ffjdskl/jfkljkdfls" - with pytest.raises(Exception): - look_up_version_from_image_tag(str) + +def test_no_images(): + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + imgs = Images.get_specified_images(cfg) + assert imgs == {} diff --git a/tests/flytekit/unit/configuration/test_resources.py b/tests/flytekit/unit/configuration/test_resources.py deleted file mode 100644 index 3564ebe363..0000000000 --- a/tests/flytekit/unit/configuration/test_resources.py +++ /dev/null @@ -1,26 +0,0 @@ -import os - -from flytekit.configuration import resources, set_flyte_config_file - - -def test_resource_hints_default(): - assert resources.DEFAULT_CPU_LIMIT.get() is None - assert resources.DEFAULT_CPU_REQUEST.get() is None - assert resources.DEFAULT_MEMORY_REQUEST.get() is None - assert resources.DEFAULT_MEMORY_LIMIT.get() is None - assert resources.DEFAULT_GPU_REQUEST.get() is None - assert resources.DEFAULT_GPU_LIMIT.get() is None - assert resources.DEFAULT_STORAGE_REQUEST.get() is None - assert resources.DEFAULT_STORAGE_LIMIT.get() is None - - -def test_resource_hints(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) - assert resources.DEFAULT_CPU_REQUEST.get() == "500m" - assert resources.DEFAULT_CPU_LIMIT.get() == "501m" - assert 
resources.DEFAULT_MEMORY_REQUEST.get() == "500Gi" - assert resources.DEFAULT_MEMORY_LIMIT.get() == "501Gi" - assert resources.DEFAULT_GPU_REQUEST.get() == "1" - assert resources.DEFAULT_GPU_LIMIT.get() == "2" - assert resources.DEFAULT_STORAGE_REQUEST.get() == "500Gi" - assert resources.DEFAULT_STORAGE_LIMIT.get() == "501Gi" diff --git a/tests/flytekit/unit/configuration/test_temporary_configuration.py b/tests/flytekit/unit/configuration/test_temporary_configuration.py deleted file mode 100644 index fe51f46d06..0000000000 --- a/tests/flytekit/unit/configuration/test_temporary_configuration.py +++ /dev/null @@ -1,34 +0,0 @@ -import os as _os - -from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration -from flytekit.configuration import common as _common -from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file - - -def test_configuration_file(): - with _TemporaryConfiguration(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")): - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None - - -def test_internal_overrides(): - with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), - {"foo": "bar"}, - ): - assert _os.environ.get("FLYTE_INTERNAL_FOO") == "bar" - assert _os.environ.get("FLYTE_INTERNAL_FOO") is None - - -def test_no_configuration_file(): - _set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")) - with _TemporaryConfiguration(None): - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" - - -def test_nonexist_configuration_file(): - 
_set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")) - with _TemporaryConfiguration("/foo/bar"): - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None - assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" diff --git a/tests/flytekit/unit/configuration/test_waterfall.py b/tests/flytekit/unit/configuration/test_waterfall.py deleted file mode 100644 index a4cd749a80..0000000000 --- a/tests/flytekit/unit/configuration/test_waterfall.py +++ /dev/null @@ -1,44 +0,0 @@ -import os as _os - -from flytekit.configuration import common as _common -from flytekit.core.utils import AutoDeletingTempDir as _AutoDeletingTempDir - - -def test_lookup_waterfall_raw_env_var(): - x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - - if "FLYTE_TEST_SETTING" in _os.environ: - del _os.environ["FLYTE_TEST_SETTING"] - assert x.get() is None - - _os.environ["FLYTE_TEST_SETTING"] = "lorem" - assert x.get() == "lorem" - - -def test_lookup_waterfall_referenced_env_var(): - x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - - if "FLYTE_TEST_SETTING" in _os.environ: - del _os.environ["FLYTE_TEST_SETTING"] - assert x.get() is None - - if "TEMP_PLACEHOLDER" in _os.environ: - del _os.environ["TEMP_PLACEHOLDER"] - _os.environ["TEMP_PLACEHOLDER"] = "lorem" - _os.environ["FLYTE_TEST_SETTING_FROM_ENV_VAR"] = "TEMP_PLACEHOLDER" - assert x.get() == "lorem" - - -def test_lookup_waterfall_referenced_file(): - x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - - if "FLYTE_TEST_SETTING" in _os.environ: - del _os.environ["FLYTE_TEST_SETTING"] - assert x.get() is None - - with _AutoDeletingTempDir("config_testing") as tmp_dir: - with open(tmp_dir.get_named_tempfile("name"), "w") as fh: - fh.write("secret_password") - - _os.environ["FLYTE_TEST_SETTING_FROM_FILE"] = 
tmp_dir.get_named_tempfile("name") - assert x.get() == "secret_password" diff --git a/tests/flytekit/unit/core/test_complex_nesting.py b/tests/flytekit/unit/core/test_complex_nesting.py index 13bb108f7c..bf8fbb0e5f 100644 --- a/tests/flytekit/unit/core/test_complex_nesting.py +++ b/tests/flytekit/unit/core/test_complex_nesting.py @@ -6,7 +6,8 @@ import pytest from dataclasses_json import dataclass_json -from flytekit.core.context_manager import ExecutionState, FlyteContextManager, Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.type_engine import TypeEngine from flytekit.types.directory import FlyteDirectory @@ -141,10 +142,6 @@ def dt1(a: List[MyInput]) -> List[FlyteFile]: ctx.with_execution_state( ctx.execution_state.with_params( mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, ) ) ) as ctx: @@ -207,8 +204,10 @@ def dt1(a: List[MyInput]) -> List[FlyteDirectory]: return x - with FlyteContextManager.with_context( - FlyteContextManager.current_context().with_serialization_settings( + ctx = FlyteContextManager.current_context() + cb = ( + ctx.new_builder() + .with_serialization_settings( SerializationSettings( project="test_proj", domain="test_domain", @@ -217,21 +216,12 @@ def dt1(a: List[MyInput]) -> List[FlyteDirectory]: env={}, ) ) - ) as ctx: - with FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params( - mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, - ) - ) - ) as ctx: - input_literal_map = TypeEngine.dict_to_literal_map( - ctx, d={"a": [my_input_gcs, 
my_input_gcs_2]}, guessed_python_types={"a": List[MyInput]} - ) - dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map) - assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two" - assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four" + .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ) + with FlyteContextManager.with_context(cb) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map( + ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, guessed_python_types={"a": List[MyInput]} + ) + dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map) + assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two" + assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four" diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 13cfcf3706..598e3a3eef 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -4,10 +4,11 @@ import mock import pytest +import flytekit.configuration from flytekit import task, workflow +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.condition import conditional -from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings from flytekit.models.core.workflow import Node from flytekit.tools.translator import get_serializable @@ -226,7 +227,7 @@ def nested_branching(x: int) -> int: return conditional("nested test").if_(x == 2).then(ifelse_branching(x=x)).else_().then(wf5()) default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", 
domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index d237090367..ddf39ba552 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -3,14 +3,15 @@ import py import pytest -from flytekit.configuration import secrets -from flytekit.core.context_manager import ( - ExecutionState, - FlyteContext, - FlyteContextManager, - SecretsManager, - look_up_image_info, +from flytekit.configuration import ( + EntrypointSettings, + FastSerializationSettings, + Image, + ImageConfig, + SecretsConfig, + SerializationSettings, ) +from flytekit.core.context_manager import FlyteContext, FlyteContextManager, SecretsManager class SampleTestClass(object): @@ -39,44 +40,65 @@ def test_default(): def test_look_up_image_info(): - img = look_up_image_info(name="x", tag="docker.io/xyz", optional_tag=True) + img = Image.look_up_image_info(name="x", tag="docker.io/xyz", optional_tag=True) assert img.name == "x" assert img.tag is None assert img.fqn == "docker.io/xyz" - img = look_up_image_info(name="x", tag="docker.io/xyz:latest", optional_tag=True) + img = Image.look_up_image_info(name="x", tag="docker.io/xyz:latest", optional_tag=True) assert img.name == "x" assert img.tag == "latest" assert img.fqn == "docker.io/xyz" - img = look_up_image_info(name="x", tag="docker.io/xyz:latest", optional_tag=False) + img = Image.look_up_image_info(name="x", tag="docker.io/xyz:latest", optional_tag=False) assert img.name == "x" assert img.tag == "latest" assert img.fqn == "docker.io/xyz" - img = look_up_image_info(name="x", tag="localhost:5000/xyz:latest", optional_tag=False) + img = Image.look_up_image_info(name="x", tag="localhost:5000/xyz:latest", optional_tag=False) assert img.name == "x" assert img.tag == "latest" assert img.fqn == "localhost:5000/xyz" -def test_additional_context(): - ctx = FlyteContext.current_context() - with 
FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.new_execution_state().with_params( - mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "outer", 2: "foo"} - ) - ) - ) as exec_ctx_outer: - with FlyteContextManager.with_context( - ctx.with_execution_state( - exec_ctx_outer.execution_state.with_params( - mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "inner", 3: "baz"} - ) - ) - ) as exec_ctx_inner: - assert exec_ctx_inner.execution_state.additional_context == {1: "inner", 2: "foo", 3: "baz"} +def test_validate_image(): + ic = ImageConfig.validate_image(None, "image", ()) + assert ic + assert ic.default_image is None + + img1 = "xyz:latest" + img2 = "docker.io/xyz:latest" + img3 = "docker.io/xyz:latest" + img3_cli = f"default={img3}" + img4 = "docker.io/my:azb" + img4_cli = f"my_img={img4}" + + ic = ImageConfig.validate_image(None, "image", (img1,)) + assert ic + assert ic.default_image.full == img1 + + ic = ImageConfig.validate_image(None, "image", (img2,)) + assert ic + assert ic.default_image.full == img2 + + ic = ImageConfig.validate_image(None, "image", (img3_cli,)) + assert ic + assert ic.default_image.full == img3 + + with pytest.raises(ValueError): + ImageConfig.validate_image(None, "image", (img1, img3_cli)) + + with pytest.raises(ValueError): + ImageConfig.validate_image(None, "image", (img1, img2)) + + with pytest.raises(ValueError): + ImageConfig.validate_image(None, "image", (img1, img1)) + + ic = ImageConfig.validate_image(None, "image", (img3_cli, img4_cli)) + assert ic + assert ic.default_image.full == img3 + assert len(ic.images) == 1 + assert ic.images[0].full == img4 def test_secrets_manager_default(): @@ -94,7 +116,8 @@ def test_secrets_manager_get_envvar(): sec.get_secrets_env_var("test", "") with pytest.raises(ValueError): sec.get_secrets_env_var("", "x") - assert sec.get_secrets_env_var("group", "test") == f"{secrets.SECRETS_ENV_PREFIX.get()}GROUP_TEST" + cfg = SecretsConfig.auto() + assert 
sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST" def test_secrets_manager_get_file(): @@ -103,10 +126,11 @@ def test_secrets_manager_get_file(): sec.get_secrets_file("test", "") with pytest.raises(ValueError): sec.get_secrets_file("", "x") + cfg = SecretsConfig.auto() assert sec.get_secrets_file("group", "test") == os.path.join( - secrets.SECRETS_DEFAULT_DIR.get(), + cfg.default_dir, "group", - f"{secrets.SECRETS_FILE_PREFIX.get()}test", + f"{cfg.file_prefix}test", ) @@ -150,3 +174,30 @@ def test_secrets_manager_env(): os.environ[sec.get_secrets_env_var(group="group", key="key")] = "value" assert sec.get(group="group", key="key") == "value" + + +def test_serialization_settings_transport(): + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"hello": "blah"}, + image_config=ImageConfig( + default_image=default_img, + images=[default_img], + ), + flytekit_virtualenv_root="/opt/venv/blah", + python_interpreter="/opt/venv/bin/python3", + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/opt/blah/blah/blah", + distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz", + ), + ) + + tp = serialization_settings.prepare_for_transport() + ss = SerializationSettings.from_transport(tp) + assert ss is not None + assert ss == serialization_settings + assert len(tp) == 376 diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 9da8b6b5cd..283db3f357 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -1,8 +1,10 @@ import typing +import flytekit.configuration from flytekit import dynamic +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager 
import ExecutionState, FastSerializationSettings, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow @@ -28,13 +30,17 @@ def my_wf(a: int) -> typing.List[str]: with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), env={}, - fast_serialization_settings=FastSerializationSettings(enabled=True), + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + distribution_location="s3://my-s3-bucket/fast/123", + ), ) ) ) as ctx: @@ -42,10 +48,6 @@ def my_wf(a: int) -> typing.List[str]: ctx.with_execution_state( ctx.execution_state.with_params( mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, ) ) ) as ctx: diff --git a/tests/flytekit/unit/core/test_dynamic_conditional.py b/tests/flytekit/unit/core/test_dynamic_conditional.py index a5c381944d..8c34f34759 100644 --- a/tests/flytekit/unit/core/test_dynamic_conditional.py +++ b/tests/flytekit/unit/core/test_dynamic_conditional.py @@ -2,10 +2,12 @@ from datetime import datetime from random import seed +import flytekit.configuration from flytekit import dynamic, task, workflow +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.condition import conditional -from flytekit.core.context_manager import ExecutionState, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState # seed random number generator 
seed(datetime.now().microsecond) @@ -79,7 +81,7 @@ def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]: with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 93a4d0e039..0cb4f524f9 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -7,8 +7,10 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState, FlyteContextManager, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.task import task @@ -173,7 +175,7 @@ def dyn(in1: FlyteDirectory): ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context( ctx.with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 9132be082c..5209d21507 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -6,8 +6,10 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState, Image, ImageConfig +from 
flytekit.core.context_manager import ExecutionState from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.launch_plan import LaunchPlan @@ -243,7 +245,7 @@ def dyn(in1: FlyteFile): with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index c6134558b4..a3ec6d17ce 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -1,8 +1,9 @@ from collections import OrderedDict from typing import Dict, List +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.models.core.types import BlobType from flytekit.models.literals import BlobMetadata @@ -11,7 +12,7 @@ from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index 2d34466862..bf99f0c137 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -4,9 +4,10 @@ import pandas as pd import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.base_task 
import kwtypes -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import reference_task, task from flytekit.core.workflow import ImperativeWorkflow, get_promise, workflow @@ -19,7 +20,7 @@ from flytekit.types.structured.structured_dataset import StructuredDatasetType default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_imperative_with_patching.py b/tests/flytekit/unit/core/test_imperative_with_patching.py index 2db305fa3e..366413463a 100644 --- a/tests/flytekit/unit/core/test_imperative_with_patching.py +++ b/tests/flytekit/unit/core/test_imperative_with_patching.py @@ -1,14 +1,15 @@ import pytest from mock import patch as _system_patch +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.core.testing import patch as flyte_patch from flytekit.core.workflow import ImperativeWorkflow, workflow default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index baa33cd356..4a1cf9d8ba 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -4,8 +4,9 @@ import pytest from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from 
flytekit.core import context_manager, launch_plan, notification -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.schedule import CronSchedule from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -15,7 +16,7 @@ from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 4eb44d6e76..a1d04659d6 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -3,9 +3,10 @@ import pytest +import flytekit.configuration from flytekit import LaunchPlan, map_task +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.map_task import MapPythonTask from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow @@ -15,7 +16,7 @@ @pytest.fixture def serialization_settings(): default_img = Image(name="default", fqn="test", tag="tag") - return context_manager.SerializationSettings( + return flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 811b2d46e5..a722328420 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -4,9 +4,10 @@ import pytest +import flytekit.configuration from flytekit import Resources, map_task +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, 
ImageConfig from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task @@ -39,7 +40,7 @@ def my_wf(a: str) -> (str, typing.List[str]): assert r == "hello world" assert x == ["0 world", "1 world", "2 world"] - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -73,7 +74,7 @@ def empty_wf2(): t3_node = create_node(t3) t3_node >> t2_node - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -183,7 +184,7 @@ def my_wf(a: typing.List[str]) -> typing.List[str]: map_node = mappy(a=a).with_overrides(requests=Resources(cpu="1", mem="100", ephemeral_storage="500Mi")) return map_node - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -212,7 +213,7 @@ def my_wf(a: typing.List[str]) -> typing.List[str]: map_node = mappy(a=a).with_overrides(limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi")) return map_node - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -243,7 +244,7 @@ def my_wf(a: typing.List[str]) -> typing.List[str]: ) return map_node - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -279,7 +280,7 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(timeout=timeout) - serialization_settings = 
context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -315,7 +316,7 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(retries=retries) - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -337,7 +338,7 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(interruptible=interruptible) - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 4394710dd5..f7b4f34709 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -2,7 +2,7 @@ import pytest -from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 708bacfe50..34aaefaeb3 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -1,7 +1,7 @@ import pytest from flytekit import task -from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.python_auto_container import get_registerable_container_image from 
flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import isnested, istestfunction diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index 5151539236..df6e093b55 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -3,9 +3,10 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.base_task import kwtypes -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan from flytekit.core.promise import VoidPromise @@ -58,7 +59,7 @@ def test_ref(): assert ref_t1.id.name == "recipes.aaa.simple.join_strings" assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79" - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -247,7 +248,7 @@ def test_lps(resource_type): def wf1(a: str, b: int): ref_entity(a=a, b=b) - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -294,7 +295,7 @@ def test_ref_sub_wf(): def wf1(a: str, b: int): ref_entity(a=a, b=b) - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -342,7 +343,7 @@ def inner_test(ref_mock): inner_test() - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ 
-412,7 +413,7 @@ def my_subwf(a: int) -> typing.List[str]: with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -441,7 +442,7 @@ def ref_lp1(p1: str, p2: str) -> int: with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index ea44099587..2a573b268d 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -3,17 +3,18 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.python_auto_container import default_task_resolver from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 11616203cd..91e2c877ec 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -4,18 +4,18 @@ import 
pytest +import flytekit.configuration from flytekit import ContainerTask, kwtypes -from flytekit.configuration import set_flyte_config_file -from flytekit.core import context_manager +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional -from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings, get_image_config +from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.models.types import SimpleType from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", @@ -50,7 +50,7 @@ def raw_container_wf(val1: int, val2: int) -> int: return sum(x=square(val=val1), y=square(val=val2)) default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", @@ -138,7 +138,7 @@ def my_wf(a: int) -> int: return d default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", @@ -176,7 +176,7 @@ def my_wf(a: int, b: str) -> (int, str): return x, f default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", @@ -210,7 +210,7 @@ def my_wf(a: int) -> str: assert my_wf(a=2) == 
"hello" default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", @@ -223,51 +223,56 @@ def my_wf(a: int) -> str: assert wf_spec.template.nodes[1].branch_node is not None +def test_bad_configuration(): + container_image = "{{.image.xyz.fqn}}:{{.image.default.version}}" + image_config = ImageConfig.auto( + config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config") + ) + # No default image in the images.config file so nothing to pull version from + with pytest.raises(AssertionError): + get_registerable_container_image(container_image, image_config) + + def test_serialization_images(): - @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}") + @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}") def t1(a: int) -> int: return a - @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}") + @task(container_image="{{.image.abc.fqn}}:{{.image.xyz.version}}") def t2(): pass - @task - def t3(): - pass - @task(container_image="docker.io/org/myimage:latest") def t4(): pass - @task(container_image="docker.io/org/myimage:{{.image.default.version}}") + @task(container_image="docker.io/org/myimage:{{.image.xyz.version}}") def t5(a: int) -> int: return a os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version" - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")) - rs = context_manager.SerializationSettings( + imgs = ImageConfig.auto( + config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config") + ) + rs = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", env=None, - image_config=get_image_config(), + image_config=imgs, ) t1_spec = get_serializable(OrderedDict(), 
rs, t1) - assert t1_spec.template.container.image == "docker.io/xyz:version" + assert t1_spec.template.container.image == "docker.io/xyz:latest" t1_spec.to_flyte_idl() t2_spec = get_serializable(OrderedDict(), rs, t2) - assert t2_spec.template.container.image == "docker.io/default:version" - - t3_spec = get_serializable(OrderedDict(), rs, t3) - assert t3_spec.template.container.image == "docker.io/default:version" + assert t2_spec.template.container.image == "docker.io/abc:latest" t4_spec = get_serializable(OrderedDict(), rs, t4) assert t4_spec.template.container.image == "docker.io/org/myimage:latest" t5_spec = get_serializable(OrderedDict(), rs, t5) - assert t5_spec.template.container.image == "docker.io/org/myimage:version" + assert t5_spec.template.container.image == "docker.io/org/myimage:latest" def test_serialization_command1(): diff --git a/tests/flytekit/unit/core/test_shim_task.py b/tests/flytekit/unit/core/test_shim_task.py index b22f2e7349..361b451181 100644 --- a/tests/flytekit/unit/core/test_shim_task.py +++ b/tests/flytekit/unit/core/test_shim_task.py @@ -3,15 +3,16 @@ import mock +import flytekit.configuration from flytekit import ContainerTask, kwtypes +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask, TaskTemplateResolver from flytekit.core.utils import write_proto_to_file from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 8efaecfcc9..a567a564fc 100644 --- 
a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -2,8 +2,10 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import FlyteContext, FlyteContextManager, Image, ImageConfig +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata @@ -34,7 +36,7 @@ fields = [("some_int", pa.int32()), ("some_string", pa.string())] arrow_schema = pa.schema(fields) -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="proj", domain="dom", version="123", diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index e5ed338599..672d739532 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -17,10 +17,12 @@ from typing_extensions import Annotated import flytekit +import flytekit.configuration from flytekit import ContainerTask, Secret, SQLTask, dynamic, kwtypes, map_task +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional -from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.hash import HashMethod from flytekit.core.node import Node @@ -41,7 +43,7 @@ from flytekit.types.schema import FlyteSchema, SchemaOpenMode from flytekit.types.structured.structured_dataset import StructuredDataset -serialization_settings = 
context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="proj", domain="dom", version="123", @@ -564,7 +566,7 @@ def my_wf(a: int, b: str) -> (str, typing.List[str]): with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -602,13 +604,17 @@ def my_wf(a: int) -> typing.List[str]: with context_manager.FlyteContextManager.with_context( context_manager.FlyteContextManager.current_context().with_serialization_settings( - context_manager.SerializationSettings( + flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), env={}, - fast_serialization_settings=FastSerializationSettings(enabled=True), + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + distribution_location="s3://my-s3-bucket/fast/123", + ), ) ) ) as ctx: @@ -616,10 +622,6 @@ def my_wf(a: int) -> typing.List[str]: ctx.with_execution_state( ctx.execution_state.with_params( mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, ) ) ) as ctx: @@ -864,7 +866,7 @@ def my_subwf(a: int) -> (str, str): lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf) lp_with_defaults = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3}) - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="proj", domain="dom", version="123", @@ -1248,7 +1250,7 @@ def my_wf(a: int) -> str: x = t1(a=a) return x - 
serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -1281,7 +1283,7 @@ def my_wf(a: int) -> str: x = t1(a=a) return x - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", @@ -1396,7 +1398,7 @@ def my_subwf(a: int) -> typing.List[str]: x = my_wf(a=v, b="hello ") assert x == ("hello hello ", ["world-" + str(i) for i in range(2, v + 2)]) - settings = context_manager.SerializationSettings( + settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", version="abc", diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py index f999c62612..a691b2b415 100644 --- a/tests/flytekit/unit/core/test_typing_annotation.py +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -3,15 +3,16 @@ import typing_extensions +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.annotation import FlyteAnnotation -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.models.annotation import TypeAnnotation from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index c199a474de..088fccf175 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -1,3 +1,4 @@ 
+import os import typing from collections import OrderedDict @@ -5,10 +6,10 @@ import pytest from pandas.testing import assert_frame_equal +import flytekit.configuration from flytekit import StructuredDataset, kwtypes -from flytekit.core import context_manager +from flytekit.configuration import Image, ImageConfig from flytekit.core.condition import conditional -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow from flytekit.exceptions.user import FlyteValidationException, FlyteValueException @@ -21,7 +22,7 @@ from typing_extensions import Annotated default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py index bcf1fd3495..a6f29f36d6 100644 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py @@ -1,6 +1,9 @@ +from datetime import timedelta + import mock from flytekit import S3Persistence +from flytekit.configuration import DataConfig, S3Config from flytekit.extras.persistence import s3_awscli @@ -16,14 +19,12 @@ def test_construct_path(): @mock.patch("flytekit.extras.persistence.s3_awscli.S3Persistence._check_binary") -@mock.patch("flytekit.configuration.aws.BACKOFF_SECONDS") @mock.patch("flytekit.extras.persistence.s3_awscli.subprocess") -def test_retries(mock_subprocess, mock_delay, mock_check): - mock_delay.get.return_value = 0 +def test_retries(mock_subprocess, mock_check): mock_subprocess.check_call.side_effect = Exception("test exception (404)") mock_check.return_value = True - proxy = S3Persistence() + proxy = 
S3Persistence(data_config=DataConfig(s3=S3Config(backoff=timedelta(seconds=0)))) assert proxy.exists("s3://test/fdsa/fdsa") is False assert mock_subprocess.check_call.call_count == 8 @@ -48,7 +49,8 @@ def test_put(mock_exec): proxy = S3Persistence() proxy.put("/test", "s3://my-bucket/k1") mock_exec.assert_called_with( - ["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"] + cmd=["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], + s3_cfg=S3Config.auto(), ) @@ -57,7 +59,8 @@ def test_put_recursive(mock_exec): proxy = S3Persistence() proxy.put("/test", "s3://my-bucket/k1", True) mock_exec.assert_called_with( - ["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"] + cmd=["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], + s3_cfg=S3Config.auto(), ) @@ -65,11 +68,13 @@ def test_put_recursive(mock_exec): def test_get(mock_exec): proxy = S3Persistence() proxy.get("s3://my-bucket/k1", "/test") - mock_exec.assert_called_with(["aws", "s3", "cp", "s3://my-bucket/k1", "/test"]) + mock_exec.assert_called_with(cmd=["aws", "s3", "cp", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto()) @mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") def test_get_recursive(mock_exec): proxy = S3Persistence() proxy.get("s3://my-bucket/k1", "/test", True) - mock_exec.assert_called_with(["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"]) + mock_exec.assert_called_with( + cmd=["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto() + ) diff --git a/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py b/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py index b52978dd58..3db372d165 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py +++ b/tests/flytekit/unit/extras/sqlite3/test_sql_tracker.py @@ -1,7 +1,8 @@ from collections 
import OrderedDict +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager -from flytekit.core.context_manager import Image, ImageConfig from flytekit.tools.translator import get_serializable from tests.flytekit.unit.extras.sqlite3.test_task import tk as not_tk @@ -12,7 +13,7 @@ def test_sql_lhs(): def test_sql_command(): default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( + serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py index 07483487ab..00d80464c3 100644 --- a/tests/flytekit/unit/remote/test_calling.py +++ b/tests/flytekit/unit/remote/test_calling.py @@ -4,8 +4,9 @@ import pytest from flytekit import dynamic +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -19,7 +20,7 @@ from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = SerializationSettings( project="project", domain="domain", version="version", @@ -45,8 +46,7 @@ def sub_wf(a: int, b: str) -> (int, str): return x, d -serialized = OrderedDict() -t1_spec = get_serializable(serialized, serialization_settings, t1) +t1_spec = get_serializable(OrderedDict(), serialization_settings, t1) ft = 
FlyteTask.promote_from_model(t1_spec.template) @@ -124,10 +124,6 @@ def my_subwf(a: int) -> typing.List[int]: ctx.with_execution_state( ctx.execution_state.with_params( mode=ExecutionState.Mode.TASK_EXECUTION, - additional_context={ - "dynamic_addl_distro": "s3://my-s3-bucket/fast/123", - "dynamic_dest_dir": "/User/flyte/workflows", - }, ) ) ) as ctx: diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index cd80d166e2..a6718bd77a 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -1,14 +1,12 @@ -import os - import pytest from mock import MagicMock, patch -from flytekit.configuration import internal +from flytekit.configuration import Config from flytekit.exceptions import user as user_exceptions from flytekit.models import common as common_models from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution -from flytekit.remote.remote import FlyteRemote +from flytekit.remote.remote import FlyteRemote, Options CLIENT_METHODS = { ResourceType.WORKFLOW: "list_workflows_paginated", @@ -30,36 +28,28 @@ @patch("flytekit.clients.friendly.SynchronousFlyteClient") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_remote_fetch_workflow_execution(mock_insecure, mock_url, mock_client_manager): +def test_remote_fetch_workflow_execution(mock_client_manager): admin_workflow_execution = Execution( id=WorkflowExecutionIdentifier("p1", "d1", "n1"), spec=MagicMock(), closure=MagicMock(), ) - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True mock_client = MagicMock() mock_client.get_execution.return_value = admin_workflow_execution - remote = FlyteRemote.from_config("p1", "d1") + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client flyte_workflow_execution = 
remote.fetch_workflow_execution(name="n1") assert flyte_workflow_execution.id == admin_workflow_execution.id @patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_underscore_execute_uses_launch_plan_attributes(mock_insecure, mock_url, mock_wf_exec): - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True +def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec): mock_wf_exec.return_value = True mock_client = MagicMock() - remote = FlyteRemote.from_config("p1", "d1") + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client def local_assertions(*args, **kwargs): @@ -71,32 +61,33 @@ def local_assertions(*args, **kwargs): mock_client.create_execution.side_effect = local_assertions mock_entity = MagicMock() + options = Options( + labels=common_models.Labels({"a": "my_label_value"}), + annotations=common_models.Annotations({"b": "my_annotation_value"}), + auth_role=common_models.AuthRole(kubernetes_service_account="svc"), + ) remote._execute( mock_entity, inputs={}, project="proj", domain="dev", - labels=common_models.Labels({"a": "my_label_value"}), - annotations=common_models.Annotations({"b": "my_annotation_value"}), - auth_role=common_models.AuthRole(kubernetes_service_account="svc"), + options=options, ) @patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -@patch("flytekit.configuration.auth.ASSUMABLE_IAM_ROLE") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_underscore_execute_fall_back_remote_attributes(mock_insecure, mock_url, mock_iam_role, mock_wf_exec): - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True - mock_iam_role.get.return_value = "iam:some:role" +def 
test_underscore_execute_fall_back_remote_attributes(mock_wf_exec): mock_wf_exec.return_value = True mock_client = MagicMock() - remote = FlyteRemote.from_config("p1", "d1") + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client + options = Options( + auth_role=common_models.AuthRole(assumable_iam_role="iam:some:role"), + ) + def local_assertions(*args, **kwargs): execution_spec = args[3] assert execution_spec.auth_role.assumable_iam_role == "iam:some:role" @@ -110,19 +101,18 @@ def local_assertions(*args, **kwargs): inputs={}, project="proj", domain="dev", + options=options, ) @patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_execute_with_wrong_input_key(mock_insecure, mock_url, mock_wf_exec): - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True +def test_execute_with_wrong_input_key(mock_wf_exec): + # mock_url.get.return_value = "localhost" + # mock_insecure.get.return_value = True mock_wf_exec.return_value = True mock_client = MagicMock() - remote = FlyteRemote.from_config("p1", "d1") + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client mock_entity = MagicMock() @@ -137,46 +127,22 @@ def test_execute_with_wrong_input_key(mock_insecure, mock_url, mock_wf_exec): ) -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_form_config(mock_insecure, mock_url): - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = True - - FlyteRemote.from_config("p1", "d1") - assert ".flyte/config" in os.environ[internal.CONFIGURATION_PATH.env_var] - remote = FlyteRemote.from_config("p1", "d1", "fake_config") - assert "fake_config" in os.environ[internal.CONFIGURATION_PATH.env_var] - - assert remote._flyte_admin_url == 
"localhost" - assert remote._insecure is True +def test_form_config(): + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") assert remote.default_project == "p1" assert remote.default_domain == "d1" -@patch("flytekit.clients.raw._ssl_channel_credentials") -@patch("flytekit.clients.raw._secure_channel") -@patch("flytekit.configuration.platform.URL") -@patch("flytekit.configuration.platform.INSECURE") -def test_explicit_grpc_channel_credentials(mock_insecure, mock_url, mock_secure_channel, mock_ssl_channel_credentials): - mock_url.get.return_value = "localhost" - mock_insecure.get.return_value = False - - # Default mode, no explicit channel credentials - mock_ssl_channel_credentials.reset_mock() - _ = FlyteRemote.from_config("project", "domain") - - assert mock_ssl_channel_credentials.called - - mock_secure_channel.reset_mock() - mock_ssl_channel_credentials.reset_mock() - - # Explicit channel credentials - from grpc import ssl_channel_credentials - - credentials = ssl_channel_credentials(b"TEST CERTIFICATE") - - _ = FlyteRemote.from_config("project", "domain", grpc_credentials=credentials) - assert mock_secure_channel.called - assert mock_secure_channel.call_args[0][1] == credentials - assert mock_ssl_channel_credentials.call_count == 1 +@patch("flytekit.remote.remote.SynchronousFlyteClient") +def test_passing_of_kwargs(mock_client): + additional_args = { + "credentials": 1, + "options": 2, + "private_key": 3, + "compression": 4, + "root_certificates": 5, + "certificate_chain": 6, + } + FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args) + assert mock_client.called + assert mock_client.call_args[1] == additional_args diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py index f229489b14..4466d5ddc2 100644 --- a/tests/flytekit/unit/remote/test_wrapper_classes.py +++ b/tests/flytekit/unit/remote/test_wrapper_classes.py 
@@ -3,9 +3,10 @@ import pytest +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.condition import conditional -from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -13,7 +14,7 @@ from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index e21353c38e..89a1bb8478 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -1,10 +1,11 @@ import typing from collections import OrderedDict +import flytekit.configuration from flytekit import ContainerTask, Resources +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core import context_manager from flytekit.core.base_task import kwtypes -from flytekit.core.context_manager import FastSerializationSettings, Image, ImageConfig from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.reference_entity import ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask, task @@ -13,7 +14,7 @@ from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") -serialization_settings = context_manager.SerializationSettings( +serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain", version="version", diff --git a/tests/flytekit_compatibility/test_structured_dataset.py 
b/tests/flytekit_compatibility/test_structured_dataset.py index f935305210..a8eecd356a 100644 --- a/tests/flytekit_compatibility/test_structured_dataset.py +++ b/tests/flytekit_compatibility/test_structured_dataset.py @@ -1,12 +1,12 @@ import pandas as pd -from flytekit.configuration.sdk import USE_STRUCTURED_DATASET +from flytekit.configuration import internal from flytekit.core.type_engine import TypeEngine def test_pandas_is_schema_with_flag(): # This test can only be run iff USE_STRUCTURED_DATASET is not set - assert USE_STRUCTURED_DATASET.get() is False + assert not internal.LocalSDK.USE_STRUCTURED_DATASET.read() lt = TypeEngine.to_literal_type(pd.DataFrame) assert lt.schema is not None