From 0918267b817531d66f4789f236443b32c9fe51ca Mon Sep 17 00:00:00 2001 From: Romain Date: Thu, 24 Oct 2024 11:25:10 -0700 Subject: [PATCH 01/22] Prepare release 2.12.27 (#2116) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index a72c43bf842..9d7b3279d19 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.26" +metaflow_version = "2.12.27" From a37555b6aa673d495390b28b34a36672b024381e Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Fri, 25 Oct 2024 09:13:48 -0700 Subject: [PATCH 02/22] [kubernetes][jobsets] dont set replicas to 0 (#2118) --- metaflow/plugins/kubernetes/kubernetes_jobsets.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index 7a0d59ce77a..cf6c6affe2b 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -332,11 +332,8 @@ def kill(self): name=self._name, ) - # Suspend the jobset and set the replica's to Zero. - # + # Suspend the jobset obj["spec"]["suspend"] = True - for replicated_job in obj["spec"]["replicatedJobs"]: - replicated_job["replicas"] = 0 api_instance.replace_namespaced_custom_object( group=self._group, From f222585240d07f802bd592d909e38fa0098adb94 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 31 Oct 2024 02:20:15 -0700 Subject: [PATCH 03/22] Refactor system event logger and monitor, add new metrics (#2065) * Emit graph info to event logger, add runtime start metric * Remove graph_info logging from resume * Simplify run and resume metrics * Remove logger update context * Refactor monitor and logger to not use update_context * Address comments --- metaflow/cli.py | 27 +++++++++++++++++++++++++++ metaflow/system/system_logger.py | 20 +------------------- metaflow/system/system_monitor.py | 24 ------------------------ metaflow/task.py | 11 ++++------- 4 files changed, 32 insertions(+), 50 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 64800b189b7..a318b84a3ec 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -716,6 +716,20 @@ def resume( if runtime.should_skip_clone_only_execution(): return + current._update_env( + { + "run_id": runtime.run_id, + } + ) + _system_logger.log_event( + level="info", + module="metaflow.resume", + name="start", + payload={ + "msg": "Resuming run", + }, + ) + with runtime.run_heartbeat(): if clone_only: runtime.clone_original_run() @@ -775,6 +789,19 @@ def run( write_file(run_id_file, runtime.run_id) obj.flow._set_constants(obj.graph, kwargs) + current._update_env( + { + "run_id": runtime.run_id, + } + ) + _system_logger.log_event( + level="info", + module="metaflow.run", + name="start", + payload={ + "msg": "Starting run", + }, + ) runtime.print_workflow_info() runtime.persist_constants() diff --git a/metaflow/system/system_logger.py b/metaflow/system/system_logger.py index d3d3e1e5a3f..b065f25b488 100644 --- a/metaflow/system/system_logger.py +++ b/metaflow/system/system_logger.py @@ -7,26 +7,11 @@ class SystemLogger(object): def __init__(self): self._logger = None self._flow_name = None - self._context = {} - self._is_context_updated = False def __del__(self): if self._flow_name == "not_a_real_flow": self.logger.terminate() - def update_context(self, context: Dict[str, Any]): - """ - Update the global context maintained by the system logger. 
- - Parameters - ---------- - context : Dict[str, Any] - A dictionary containing the context to update. - - """ - self._is_context_updated = True - self._context.update(context) - def init_system_logger( self, flow_name: str, logger: "metaflow.event_logger.NullEventLogger" ): @@ -71,7 +56,7 @@ def _debug(msg: str): "false", "", ): - print("system monitor: %s" % msg, file=sys.stderr) + print("system logger: %s" % msg, file=sys.stderr) def log_event( self, level: str, module: str, name: str, payload: Optional[Any] = None @@ -96,8 +81,5 @@ def log_event( "module": module, "name": name, "payload": payload if payload is not None else {}, - "context": self._context, - "is_context_updated": self._is_context_updated, } ) - self._is_context_updated = False diff --git a/metaflow/system/system_monitor.py b/metaflow/system/system_monitor.py index 3701607f34d..721f54e38fa 100644 --- a/metaflow/system/system_monitor.py +++ b/metaflow/system/system_monitor.py @@ -9,35 +9,11 @@ class SystemMonitor(object): def __init__(self): self._monitor = None self._flow_name = None - self._context = {} def __del__(self): if self._flow_name == "not_a_real_flow": self.monitor.terminate() - def update_context(self, context: Dict[str, Any]): - """ - Update the global context maintained by the system monitor. - - Parameters - ---------- - context : Dict[str, Any] - A dictionary containing the context to update. - - """ - from metaflow.sidecar import Message, MessageTypes - - self._context.update(context) - self.monitor.send( - Message( - MessageTypes.MUST_SEND, - { - "is_context_updated": True, - **self._context, - }, - ) - ) - def init_system_monitor( self, flow_name: str, monitor: "metaflow.monitor.NullMonitor" ): diff --git a/metaflow/task.py b/metaflow/task.py index bccaf47c668..bba15c45471 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -306,8 +306,6 @@ def clone_only( "origin_run_id": origin_run_id, "origin_task_id": origin_task_id, } - _system_logger.update_context(task_payload) - _system_monitor.update_context(task_payload) msg = "Cloning task from {}/{}/{}/{} to {}/{}/{}/{}".format( self.flow.name, @@ -545,9 +543,6 @@ def run_step( "project_flow_name": current.get("project_flow_name"), "trace_id": trace_id or None, } - - _system_logger.update_context(task_payload) - _system_monitor.update_context(task_payload) start = time.time() self.metadata.start_task_heartbeat(self.flow.name, run_id, step_name, task_id) with self.monitor.measure("metaflow.task.duration"): @@ -592,7 +587,8 @@ def run_step( { "parameter_names": self._init_parameters( inputs[0], passdown=True - ) + ), + "graph_info": self.flow._graph_info, } ) else: @@ -616,7 +612,8 @@ def run_step( { "parameter_names": self._init_parameters( inputs[0], passdown=False - ) + ), + "graph_info": self.flow._graph_info, } ) From 6b6bc41e810bed44e9c6b923be4700195d975816 Mon Sep 17 00:00:00 2001 From: Darin Date: Thu, 31 Oct 2024 17:39:17 -0700 Subject: [PATCH 04/22] prepare release 2.12.28 (#2125) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index 9d7b3279d19..11c1fdd73ce 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.27" +metaflow_version = "2.12.28" From b4c5f292b6b1145c807da347068cecfe92ecf6cc Mon Sep 17 00:00:00 2001 From: Chaoying Wang Date: Fri, 1 Nov 2024 02:23:45 -0700 Subject: [PATCH 05/22] ping r version to 4.4.1 to fix test (#2127) --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3a5bc9f857f..8ae808b7945 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,7 +61,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-20.04] - ver: ['4.4'] + ver: ['4.4.1'] steps: - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 From 5450af5d6b3e537a5c52d60035a0ca653b07f48e Mon Sep 17 00:00:00 2001 From: madhur-ob <155637867+madhur-ob@users.noreply.github.com> Date: Wed, 6 Nov 2024 22:43:31 +0530 Subject: [PATCH 06/22] dont support slurm with airflow and sfn (#2134) --- metaflow/plugins/airflow/airflow_cli.py | 5 +++++ metaflow/plugins/aws/step_functions/step_functions_cli.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/metaflow/plugins/airflow/airflow_cli.py b/metaflow/plugins/airflow/airflow_cli.py index 7c4c330e030..b80d82fbcee 100644 --- a/metaflow/plugins/airflow/airflow_cli.py +++ b/metaflow/plugins/airflow/airflow_cli.py @@ -389,6 +389,11 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout): "Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported." % node.name ) + if any([d.name == "slurm" for d in node.decorators]): + raise NotSupportedException( + "Step *%s* is marked for execution on Slurm with Airflow which isn't currently supported." + % node.name + ) SUPPORTED_DATASTORES = ("azure", "s3", "gs") if flow_datastore.TYPE not in SUPPORTED_DATASTORES: raise AirflowException( diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 0502c092245..efa4e7f35f4 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -154,6 +154,13 @@ def create( use_distributed_map=False, deployer_attribute_file=None, ): + for node in obj.graph: + if any([d.name == "slurm" for d in node.decorators]): + raise MetaflowException( + "Step *%s* is marked for execution on Slurm with AWS Step Functions which isn't currently supported." 
+ % node.name + ) + validate_tags(tags) if deployer_attribute_file: From 4dc910a7e4b1db9c7607fef771470dcb1ceae30a Mon Sep 17 00:00:00 2001 From: madhur-ob <155637867+madhur-ob@users.noreply.github.com> Date: Wed, 6 Nov 2024 22:51:57 +0530 Subject: [PATCH 07/22] use METAFLOW_CARD_LOCALROOT for card server (#2121) * use the correct path * suggested changes --- metaflow/plugins/cards/card_server.py | 53 ++++++++++++++++++++------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/metaflow/plugins/cards/card_server.py b/metaflow/plugins/cards/card_server.py index 7b5008410ba..d0050c6559e 100644 --- a/metaflow/plugins/cards/card_server.py +++ b/metaflow/plugins/cards/card_server.py @@ -20,14 +20,9 @@ class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): from .card_client import CardContainer from .exception import CardNotPresentException from .card_resolver import resolve_paths_from_task -from metaflow.metaflow_config import DATASTORE_LOCAL_DIR from metaflow import namespace -from metaflow.exception import ( - CommandException, - MetaflowNotFound, - MetaflowNamespaceMismatch, -) - +from metaflow.exception import MetaflowNotFound +from metaflow.plugins.datastores.local_storage import LocalStorage VIEWER_PATH = os.path.join( os.path.dirname(os.path.abspath(__file__)), "card_viewer", "viewer.html" @@ -50,18 +45,48 @@ class RunWatcher(Thread): def __init__(self, flow_name, connection: Connection): super().__init__() - self._watch_file = os.path.join( - os.getcwd(), DATASTORE_LOCAL_DIR, flow_name, "latest_run" - ) - self._current_run_id = self.get_run_id() self.daemon = True self._connection = connection + self._flow_name = flow_name + + self._watch_file = self._initialize_watch_file() + if self._watch_file is None: + _ClickLogger( + "Warning: Could not initialize watch file location.", fg="yellow" + ) + + self._current_run_id = self.get_run_id() + + def _initialize_watch_file(self): + local_root = LocalStorage.datastore_root + if local_root is None: + local_root = LocalStorage.get_datastore_root_from_config( + lambda _: None, create_on_absent=False + ) + + return ( + os.path.join(local_root, self._flow_name, "latest_run") + if local_root + else None + ) def get_run_id(self): - if not os.path.exists(self._watch_file): + # Try to reinitialize watch file if needed + if not self._watch_file: + self._watch_file = self._initialize_watch_file() + + # Early return if watch file is still None or doesn't exist + if not (self._watch_file and os.path.exists(self._watch_file)): + return None + + try: + with open(self._watch_file, "r") as f: + return f.read().strip() + except (IOError, OSError) as e: + _ClickLogger( + "Warning: Could not read run ID from watch file: %s" % e, fg="yellow" + ) return None - with open(self._watch_file, "r") as f: - return f.read().strip() def watch(self): while True: From 65fd88891dade200f5697f17e57787638ee97a98 Mon Sep 17 00:00:00 2001 From: madhur-ob <155637867+madhur-ob@users.noreply.github.com> Date: Thu, 7 Nov 2024 00:26:09 +0530 Subject: [PATCH 08/22] better error message with dump (#2130) * better error message with dump * only print on final failure --- metaflow/cli.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index a318b84a3ec..1fc6a14953f 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -282,21 +282,31 @@ def dump(obj, input_path, private=None, max_value_size=None, include=None, file= else: ds_list = list(datastore_set) # get all tasks + tasks_processed = 
False for ds in ds_list: - echo( - "Dumping output of run_id=*{run_id}* " - "step=*{step}* task_id=*{task_id}*".format( - run_id=ds.run_id, step=ds.step_name, task_id=ds.task_id - ), - fg="magenta", - ) - - if file is None: - echo_always( - ds.format(**kwargs), highlight="green", highlight_bold=False, err=False + if ds is not None: + tasks_processed = True + echo( + "Dumping output of run_id=*{run_id}* " + "step=*{step}* task_id=*{task_id}*".format( + run_id=ds.run_id, step=ds.step_name, task_id=ds.task_id + ), + fg="magenta", ) - else: - output[ds.pathspec] = ds.to_dict(**kwargs) + + if file is None: + echo_always( + ds.format(**kwargs), + highlight="green", + highlight_bold=False, + err=False, + ) + else: + output[ds.pathspec] = ds.to_dict(**kwargs) + + if not tasks_processed: + echo(f"No task(s) found for pathspec {input_path}", fg="red") + return if file is not None: with open(file, "wb") as f: From 81eb31b52e08731eca84664fc51e150560fc22ef Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 7 Nov 2024 00:29:38 -0800 Subject: [PATCH 09/22] Update escape hatch sys path when used in jupyter notebook (#2132) --- metaflow/plugins/env_escape/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/env_escape/__init__.py b/metaflow/plugins/env_escape/__init__.py index bba84fc6ce7..65f78029ff4 100644 --- a/metaflow/plugins/env_escape/__init__.py +++ b/metaflow/plugins/env_escape/__init__.py @@ -124,9 +124,9 @@ def load(): cur_path = os.path.dirname(__file__) sys.path = [p for p in old_paths if p != cur_path] # Handle special case where we launch a shell (including with a command) - # and we are in the CWD (searched if '' is the first element of sys.path) - if cur_path == os.getcwd() and sys.path[0] == '': - sys.path = sys.path[1:] + # and we are in the CWD (searched if '' is present in sys.path) + if cur_path == os.getcwd() and '' in sys.path: + sys.path.remove("") # Remove the module (this file) to reload it properly. Do *NOT* update sys.modules but # modify directly since it may be referenced elsewhere From dfc4a7178b5458b0dba71646487ed26e31877abf Mon Sep 17 00:00:00 2001 From: madhur-ob <155637867+madhur-ob@users.noreply.github.com> Date: Fri, 8 Nov 2024 02:16:22 +0530 Subject: [PATCH 10/22] deployer inheritance (#2135) * Move DeployedFlow and Triggered run to a more inheritance model. The _enrich_* methods were making it very hard to generate proper stubs. This will hopefully improve that. * Fix circular dependency and address comments Some doc cleanup as well. 
* Removed extraneous code * Forgot file * Update NBDeployer to forward to underlying deployer * Fix circular deps and other crap * Added stub-gen for deployer (try 2) * Fix stub test * Fix dataclients * fix docstrings for deployer (#2119) * Renamed metadata directory to metadata_provider (and related changes) There was a clash between the `metadata` function and the `metadata` module which caused issues with stubs (at least the new way of generating) * Forgot one * Typo * Make nbdocs happy * add docstrings for injected methods --------- Co-authored-by: Romain Cledat --- metaflow/__init__.py | 5 +- metaflow/client/core.py | 4 +- metaflow/clone_util.py | 2 +- metaflow/cmd/develop/stub_generator.py | 856 +++++++++++++----- metaflow/datastore/task_datastore.py | 2 +- metaflow/extension_support/plugins.py | 1 + metaflow/flowspec.py | 4 +- metaflow/includefile.py | 22 +- .../__init__.py | 0 .../heartbeat.py | 0 .../metadata.py | 0 .../{metadata => metadata_provider}/util.py | 0 metaflow/metaflow_config.py | 4 + metaflow/metaflow_current.py | 2 +- metaflow/parameters.py | 3 + metaflow/plugins/__init__.py | 15 +- metaflow/plugins/airflow/airflow_decorator.py | 2 +- .../plugins/argo/argo_workflows_decorator.py | 2 +- .../plugins/argo/argo_workflows_deployer.py | 340 ++----- .../argo/argo_workflows_deployer_objects.py | 381 ++++++++ metaflow/plugins/aws/batch/batch_cli.py | 2 +- metaflow/plugins/aws/batch/batch_decorator.py | 4 +- .../step_functions_decorator.py | 2 +- .../step_functions/step_functions_deployer.py | 289 ++---- .../step_functions_deployer_objects.py | 236 +++++ metaflow/plugins/azure/includefile_support.py | 2 + metaflow/plugins/cards/card_cli.py | 5 +- .../plugins/cards/card_modules/components.py | 18 +- metaflow/plugins/datatools/local.py | 2 + metaflow/plugins/datatools/s3/s3.py | 2 + metaflow/plugins/gcp/includefile_support.py | 3 + metaflow/plugins/kubernetes/kubernetes_cli.py | 2 +- .../kubernetes/kubernetes_decorator.py | 9 +- .../__init__.py | 0 .../{metadata => metadata_providers}/local.py | 4 +- .../service.py | 4 +- metaflow/plugins/parallel_decorator.py | 2 +- metaflow/plugins/pypi/conda_decorator.py | 2 +- .../test_unbounded_foreach_decorator.py | 2 +- metaflow/runner/click_api.py | 4 + metaflow/runner/deployer.py | 408 +++------ metaflow/runner/deployer_impl.py | 167 ++++ metaflow/runner/metaflow_runner.py | 19 +- metaflow/runner/nbdeploy.py | 25 +- metaflow/runner/nbrun.py | 6 +- metaflow/runner/utils.py | 63 +- metaflow/runtime.py | 2 +- metaflow/task.py | 2 +- stubs/test/test_stubs.yml | 68 +- 49 files changed, 1874 insertions(+), 1125 deletions(-) rename metaflow/{metadata => metadata_provider}/__init__.py (100%) rename metaflow/{metadata => metadata_provider}/heartbeat.py (100%) rename metaflow/{metadata => metadata_provider}/metadata.py (100%) rename metaflow/{metadata => metadata_provider}/util.py (100%) create mode 100644 metaflow/plugins/argo/argo_workflows_deployer_objects.py create mode 100644 metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py rename metaflow/plugins/{metadata => metadata_providers}/__init__.py (100%) rename metaflow/plugins/{metadata => metadata_providers}/local.py (99%) rename metaflow/plugins/{metadata => metadata_providers}/service.py (99%) create mode 100644 metaflow/runner/deployer_impl.py diff --git a/metaflow/__init__.py b/metaflow/__init__.py index c901e81c38e..409922a49d6 100644 --- a/metaflow/__init__.py +++ b/metaflow/__init__.py @@ -101,9 +101,7 @@ class and related decorators. 
# Flow spec from .flowspec import FlowSpec -from .parameters import Parameter, JSONTypeClass - -JSONType = JSONTypeClass() +from .parameters import Parameter, JSONTypeClass, JSONType # data layer # For historical reasons, we make metaflow.plugins.datatools accessible as @@ -149,6 +147,7 @@ class and related decorators. from .runner.metaflow_runner import Runner from .runner.nbrun import NBRunner from .runner.deployer import Deployer + from .runner.deployer import DeployedFlow from .runner.nbdeploy import NBDeployer __ext_tl_modules__ = [] diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 9534ffcca2c..87b6a88c37c 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -39,7 +39,7 @@ from .filecache import FileCache if TYPE_CHECKING: - from metaflow.metadata import MetadataProvider + from metaflow.metadata_provider import MetadataProvider try: # python2 @@ -143,7 +143,7 @@ def default_metadata() -> str: if default: current_metadata = default[0] else: - from metaflow.plugins.metadata import LocalMetadataProvider + from metaflow.plugins.metadata_providers import LocalMetadataProvider current_metadata = LocalMetadataProvider return get_metadata() diff --git a/metaflow/clone_util.py b/metaflow/clone_util.py index 124efa33a57..5b092474333 100644 --- a/metaflow/clone_util.py +++ b/metaflow/clone_util.py @@ -1,5 +1,5 @@ import time -from .metadata import MetaDatum +from .metadata_provider import MetaDatum def clone_task_helper( diff --git a/metaflow/cmd/develop/stub_generator.py b/metaflow/cmd/develop/stub_generator.py index 0e6e9f86ae6..a43d3e72d87 100644 --- a/metaflow/cmd/develop/stub_generator.py +++ b/metaflow/cmd/develop/stub_generator.py @@ -31,13 +31,17 @@ from metaflow.debug import debug from metaflow.decorators import Decorator, FlowDecorator from metaflow.extension_support import get_aliased_modules -from metaflow.graph import deindent_docstring +from metaflow.metaflow_current import Current from metaflow.metaflow_version import get_version +from metaflow.runner.deployer import DeployedFlow, Deployer, TriggeredRun +from metaflow.runner.deployer_impl import DeployerImpl TAB = " " METAFLOW_CURRENT_MODULE_NAME = "metaflow.metaflow_current" +METAFLOW_DEPLOYER_MODULE_NAME = "metaflow.runner.deployer" param_section_header = re.compile(r"Parameters\s*\n----------\s*\n", flags=re.M) +return_section_header = re.compile(r"Returns\s*\n-------\s*\n", flags=re.M) add_to_current_header = re.compile( r"MF Add To Current\s*\n-----------------\s*\n", flags=re.M ) @@ -57,6 +61,20 @@ ] +# Object that has start() and end() like a Match object to make the code simpler when +# we are parsing different sections of doc +class StartEnd: + def __init__(self, start: int, end: int): + self._start = start + self._end = end + + def start(self): + return self._start + + def end(self): + return self._end + + def type_var_to_str(t: TypeVar) -> str: bound_name = None if t.__bound__ is not None: @@ -92,6 +110,131 @@ def descend_object(object: str, options: Iterable[str]): return False +def parse_params_from_doc(doc: str) -> Tuple[List[inspect.Parameter], bool]: + parameters = [] + no_arg_version = True + for line in doc.splitlines(): + if non_indented_line.match(line): + match = param_name_type.match(line) + arg_name = type_name = is_optional = default = None + default_set = False + if match is not None: + arg_name = match.group("name") + type_name = match.group("type") + if type_name is not None: + type_detail = type_annotations.match(type_name) + if type_detail is not None: + type_name = 
type_detail.group("type") + is_optional = type_detail.group("optional") is not None + default = type_detail.group("default") + if default: + default_set = True + try: + default = eval(default) + except: + pass + try: + type_name = eval(type_name) + except: + pass + parameters.append( + inspect.Parameter( + name=arg_name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=( + default + if default_set + else None if is_optional else inspect.Parameter.empty + ), + annotation=(Optional[type_name] if is_optional else type_name), + ) + ) + if not default_set: + # If we don't have a default set for any parameter, we can't + # have a no-arg version since the function would be incomplete + no_arg_version = False + return parameters, no_arg_version + + +def split_docs( + raw_doc: str, boundaries: List[Tuple[str, Union[StartEnd, re.Match]]] +) -> Dict[str, str]: + docs = dict() + boundaries.sort(key=lambda x: x[1].start()) + + section_start = 0 + for idx in range(1, len(boundaries)): + docs[boundaries[idx - 1][0]] = raw_doc[ + section_start : boundaries[idx][1].start() + ] + section_start = boundaries[idx][1].end() + docs[boundaries[-1][0]] = raw_doc[section_start:] + return docs + + +def parse_add_to_docs( + raw_doc: str, +) -> Dict[str, Union[Tuple[inspect.Signature, str], str]]: + prop = None + return_type = None + property_indent = None + doc = [] + add_to_docs = dict() # type: Dict[str, Union[str, Tuple[inspect.Signature, str]]] + + def _add(): + if prop: + add_to_docs[prop] = ( + inspect.Signature( + [ + inspect.Parameter( + "self", inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + ], + return_annotation=return_type, + ), + "\n".join(doc), + ) + + for line in raw_doc.splitlines(): + # Parse stanzas that look like the following: + # -> type + # indented doc string + if property_indent is not None and ( + line.startswith(property_indent + " ") or line.strip() == "" + ): + offset = len(property_indent) + if line.lstrip().startswith("@@ "): + line = line.replace("@@ ", "") + doc.append(line[offset:].rstrip()) + else: + if line.strip() == 0: + continue + if prop: + # Ends a property stanza + _add() + # Now start a new one + line = line.rstrip() + property_indent = line[: len(line) - len(line.lstrip())] + # Either this has a -> to denote a property or it is a pure name + # to denote a reference to a function (starting with #) + line = line.lstrip() + if line.startswith("#"): + # The name of the function is the last part like metaflow.deployer.run + add_to_docs[line.split(".")[-1]] = line[1:] + continue + # This is a line so we split it using "->" + prop, return_type = line.split("->") + prop = prop.strip() + return_type = return_type.strip() + doc = [] + _add() + return add_to_docs + + +def add_indent(indentation: str, text: str) -> str: + return "\n".join([indentation + line for line in text.splitlines()]) + + class StubGenerator: """ This class takes the name of a library as input and a directory as output. @@ -121,11 +264,18 @@ def __init__(self, output_dir: str, include_generated_for: bool = True): os.environ["METAFLOW_STUBGEN"] = "1" self._write_generated_for = include_generated_for - self._pending_modules = ["metaflow"] # type: List[str] - self._pending_modules.extend(get_aliased_modules()) + # First element is the name it should be installed in (alias) and second is the + # actual module name + self._pending_modules = [ + ("metaflow", "metaflow") + ] # type: List[Tuple[str, str]] self._root_module = "metaflow." 
self._safe_modules = ["metaflow.", "metaflow_extensions."] + self._pending_modules.extend( + (self._get_module_name_alias(x), x) for x in get_aliased_modules() + ) + # We exclude some modules to not create a bunch of random non-user facing # .pyi files. self._exclude_modules = set( @@ -151,7 +301,7 @@ def __init__(self, output_dir: str, include_generated_for: bool = True): "metaflow.package", "metaflow.plugins.datastores", "metaflow.plugins.env_escape", - "metaflow.plugins.metadata", + "metaflow.plugins.metadata_providers", "metaflow.procpoll.py", "metaflow.R", "metaflow.runtime", @@ -163,9 +313,16 @@ def __init__(self, output_dir: str, include_generated_for: bool = True): "metaflow._vendor", ] ) + self._done_modules = set() # type: Set[str] self._output_dir = output_dir self._mf_version = get_version() + + # Contains the names of the methods that are injected in Deployer + self._deployer_injected_methods = ( + {} + ) # type: Dict[str, Dict[str, Union[Tuple[str, str], str]]] + # Contains information to add to the Current object (injected by decorators) self._addl_current = ( dict() ) # type: Dict[str, Dict[str, Tuple[inspect.Signature, str]]] @@ -184,6 +341,7 @@ def _reset(self): self._typevars = dict() # type: Dict[str, Union[TypeVar, type]] # Current objects in the file being processed self._current_objects = {} # type: Dict[str, Any] + self._current_references = [] # type: List[str] # Current stubs in the file being processed self._stubs = [] # type: List[str] @@ -192,26 +350,78 @@ def _reset(self): # the "globals()" self._current_parent_module = None # type: Optional[ModuleType] - def _get_module(self, name): - debug.stubgen_exec("Analyzing module %s ..." % name) + def _get_module_name_alias(self, module_name): + if any( + module_name.startswith(x) for x in self._safe_modules + ) and not module_name.startswith(self._root_module): + return self._root_module + ".".join( + ["mf_extensions", *module_name.split(".")[1:]] + ) + return module_name + + def _get_relative_import( + self, new_module_name, cur_module_name, is_init_module=False + ): + new_components = new_module_name.split(".") + cur_components = cur_module_name.split(".") + init_module_count = 1 if is_init_module else 0 + common_idx = 0 + max_idx = min(len(new_components), len(cur_components)) + while ( + common_idx < max_idx + and new_components[common_idx] == cur_components[common_idx] + ): + common_idx += 1 + # current: a.b and parent: a.b.e.d -> from .e.d import + # current: a.b.c.d and parent: a.b.e.f -> from ...e.f import + return "." * (len(cur_components) - common_idx + init_module_count) + ".".join( + new_components[common_idx:] + ) + + def _get_module(self, alias, name): + debug.stubgen_exec("Analyzing module %s (aliased at %s)..." 
% (name, alias)) self._current_module = importlib.import_module(name) - self._current_module_name = name + self._current_module_name = alias for objname, obj in self._current_module.__dict__.items(): + if objname == "_addl_stubgen_modules": + debug.stubgen_exec( + "Adding modules %s from _addl_stubgen_modules" % str(obj) + ) + self._pending_modules.extend( + (self._get_module_name_alias(m), m) for m in obj + ) + continue if objname.startswith("_"): debug.stubgen_exec( "Skipping object because it starts with _ %s" % objname ) continue if inspect.ismodule(obj): - # Only consider modules that are part of the root module + # Only consider modules that are safe modules if ( - obj.__name__.startswith(self._root_module) + any(obj.__name__.startswith(m) for m in self._safe_modules) and not obj.__name__ in self._exclude_modules ): debug.stubgen_exec( "Adding child module %s to process" % obj.__name__ ) - self._pending_modules.append(obj.__name__) + + new_module_alias = self._get_module_name_alias(obj.__name__) + self._pending_modules.append((new_module_alias, obj.__name__)) + + new_parent, new_name = new_module_alias.rsplit(".", 1) + self._current_references.append( + "from %s import %s as %s" + % ( + self._get_relative_import( + new_parent, + alias, + hasattr(self._current_module, "__path__"), + ), + new_name, + objname, + ) + ) else: debug.stubgen_exec("Skipping child module %s" % obj.__name__) else: @@ -221,8 +431,10 @@ def _get_module(self, name): # we could be more specific but good enough for now) for root module. # We also include the step decorator (it's from metaflow.decorators # which is typically excluded) - # - otherwise, anything that is in safe_modules. Note this may include - # a bit much (all the imports) + # - Stuff that is defined in this module itself + # - a reference to anything in the modules we will process later + # (so we don't duplicate a ton of times) + if ( parent_module is None or ( @@ -232,43 +444,44 @@ def _get_module(self, name): or obj == step ) ) - or ( - not any( - [ - parent_module.__name__.startswith(p) - for p in self._exclude_modules - ] - ) - and any( - [ - parent_module.__name__.startswith(p) - for p in self._safe_modules - ] - ) - ) + or parent_module.__name__ == name ): debug.stubgen_exec("Adding object %s to process" % objname) self._current_objects[objname] = obj - else: - debug.stubgen_exec("Skipping object %s" % objname) - # We also include the module to process if it is part of root_module - if ( - parent_module is not None - and not any( - [ - parent_module.__name__.startswith(d) - for d in self._exclude_modules - ] - ) - and parent_module.__name__.startswith(self._root_module) + + elif not any( + [ + parent_module.__name__.startswith(p) + for p in self._exclude_modules + ] + ) and any( + [parent_module.__name__.startswith(p) for p in self._safe_modules] ): + parent_alias = self._get_module_name_alias(parent_module.__name__) + + relative_import = self._get_relative_import( + parent_alias, alias, hasattr(self._current_module, "__path__") + ) + debug.stubgen_exec( - "Adding module of child object %s to process" - % parent_module.__name__, + "Adding reference %s and adding module %s as %s" + % (objname, parent_module.__name__, parent_alias) + ) + obj_import_name = getattr(obj, "__name__", objname) + if obj_import_name == "": + # We have one case of this + obj_import_name = objname + self._current_references.append( + "from %s import %s as %s" + % (relative_import, obj_import_name, objname) ) - self._pending_modules.append(parent_module.__name__) + 
self._pending_modules.append((parent_alias, parent_module.__name__)) + else: + debug.stubgen_exec("Skipping object %s" % objname) - def _get_element_name_with_module(self, element: Union[TypeVar, type, Any]) -> str: + def _get_element_name_with_module( + self, element: Union[TypeVar, type, Any], force_import=False + ) -> str: # The element can be a string, for example "def f() -> 'SameClass':..." def _add_to_import(name): if name != self._current_module_name: @@ -292,6 +505,9 @@ def _add_to_typing_check(name, is_module=False): self._typing_imports.add(splits[0]) if isinstance(element, str): + # Special case for self referential things (particularly in a class) + if element == self._current_name: + return '"%s"' % element # We first try to eval the annotation because with the annotations future # it is always a string try: @@ -309,6 +525,9 @@ def _add_to_typing_check(name, is_module=False): pass if isinstance(element, str): + # If we are in our "safe" modules, make sure we alias properly + if any(element.startswith(x) for x in self._safe_modules): + element = self._get_module_name_alias(element) _add_to_typing_check(element) return '"%s"' % element # 3.10+ has NewType as a class but not before so hack around to check for NewType @@ -328,9 +547,12 @@ def _add_to_typing_check(name, is_module=False): return "None" return element.__name__ - _add_to_typing_check(module.__name__, is_module=True) - if module.__name__ != self._current_module_name: - return "{0}.{1}".format(module.__name__, element.__name__) + module_name = self._get_module_name_alias(module.__name__) + if force_import: + _add_to_import(module_name.split(".")[0]) + _add_to_typing_check(module_name, is_module=True) + if module_name != self._current_module_name: + return "{0}.{1}".format(module_name, element.__name__) else: return element.__name__ elif isinstance(element, type(Ellipsis)): @@ -364,7 +586,7 @@ def _add_to_typing_check(name, is_module=False): else: return "%s[%s]" % (element.__origin__, ", ".join(args_str)) elif isinstance(element, ForwardRef): - f_arg = element.__forward_arg__ + f_arg = self._get_module_name_alias(element.__forward_arg__) # if f_arg in ("Run", "Task"): # HACK -- forward references in current.py # _add_to_import("metaflow") # f_arg = "metaflow.%s" % f_arg @@ -377,9 +599,17 @@ def _add_to_typing_check(name, is_module=False): return "typing.NamedTuple" return str(element) else: - raise RuntimeError( - "Does not handle element %s of type %s" % (str(element), type(element)) - ) + if hasattr(element, "__module__"): + elem_module = self._get_module_name_alias(element.__module__) + if elem_module == "builtins": + return getattr(element, "__name__", str(element)) + _add_to_typing_check(elem_module, is_module=True) + return "{0}.{1}".format( + elem_module, getattr(element, "__name__", element) + ) + else: + # A constant + return str(element) def _exploit_annotation(self, annotation: Any, starting: str = ": ") -> str: annotation_string = "" @@ -390,23 +620,34 @@ def _exploit_annotation(self, annotation: Any, starting: str = ": ") -> str: return annotation_string def _generate_class_stub(self, name: str, clazz: type) -> str: + debug.stubgen_exec("Generating class stub for %s" % name) + skip_init = issubclass(clazz, (TriggeredRun, DeployedFlow)) + if issubclass(clazz, DeployerImpl): + if clazz.TYPE is not None: + clazz_type = clazz.TYPE.replace("-", "_") + self._deployer_injected_methods.setdefault(clazz_type, {})[ + "deployer" + ] = (self._current_module_name + "." 
+ name) + buff = StringIO() # Class prototype buff.write("class " + name.split(".")[-1] + "(") # Add super classes for c in clazz.__bases__: - name_with_module = self._get_element_name_with_module(c) + name_with_module = self._get_element_name_with_module(c, force_import=True) buff.write(name_with_module + ", ") # Add metaclass - name_with_module = self._get_element_name_with_module(clazz.__class__) + name_with_module = self._get_element_name_with_module( + clazz.__class__, force_import=True + ) buff.write("metaclass=" + name_with_module + "):\n") # Add class docstring if clazz.__doc__: buff.write('%s"""\n' % TAB) - my_doc = cast(str, deindent_docstring(clazz.__doc__)) + my_doc = inspect.cleandoc(clazz.__doc__) init_blank = True for line in my_doc.split("\n"): if init_blank and len(line.strip()) == 0: @@ -429,6 +670,8 @@ def _generate_class_stub(self, name: str, clazz: type) -> str: func_deco = "@classmethod" element = element.__func__ if key == "__init__": + if skip_init: + continue init_func = element elif key == "__annotations__": annotation_dict = element @@ -436,11 +679,201 @@ def _generate_class_stub(self, name: str, clazz: type) -> str: if not element.__name__.startswith("_") or element.__name__.startswith( "__" ): - buff.write( - self._generate_function_stub( - key, element, indentation=TAB, deco=func_deco + if ( + clazz == Deployer + and element.__name__ in self._deployer_injected_methods + ): + # This is a method that was injected. It has docs but we need + # to parse it to generate the proper signature + func_doc = inspect.cleandoc(element.__doc__) + docs = split_docs( + func_doc, + [ + ("func_doc", StartEnd(0, 0)), + ( + "param_doc", + param_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ( + "return_doc", + return_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ], ) - ) + + parameters, _ = parse_params_from_doc(docs["param_doc"]) + return_type = self._deployer_injected_methods[element.__name__][ + "deployer" + ] + + buff.write( + self._generate_function_stub( + key, + element, + sign=[ + inspect.Signature( + parameters=[ + inspect.Parameter( + "self", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + + parameters, + return_annotation=return_type, + ) + ], + indentation=TAB, + deco=func_deco, + ) + ) + elif ( + clazz == DeployedFlow and element.__name__ == "from_deployment" + ): + # We simply update the signature to list the return + # type as a union of all possible deployers + func_doc = inspect.cleandoc(element.__doc__) + docs = split_docs( + func_doc, + [ + ("func_doc", StartEnd(0, 0)), + ( + "param_doc", + param_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ( + "return_doc", + return_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ], + ) + + parameters, _ = parse_params_from_doc(docs["param_doc"]) + + def _create_multi_type(*l): + return typing.Union[l] + + all_types = [ + v["from_deployment"][0] + for v in self._deployer_injected_methods.values() + ] + + if len(all_types) > 1: + return_type = _create_multi_type(*all_types) + else: + return_type = all_types[0] if len(all_types) else None + + buff.write( + self._generate_function_stub( + key, + element, + sign=[ + inspect.Signature( + parameters=[ + inspect.Parameter( + "cls", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + + parameters, + return_annotation=return_type, + ) + ], + indentation=TAB, + doc=docs["func_doc"] + + "\n\nParameters\n----------\n" + + 
docs["param_doc"] + + "\n\nReturns\n-------\n" + + "%s\nA `DeployedFlow` object" % str(return_type), + deco=func_deco, + ) + ) + elif ( + clazz == DeployedFlow + and element.__name__.startswith("from_") + and element.__name__[5:] in self._deployer_injected_methods + ): + # Get the doc from the from_deployment method stored in + # self._deployer_injected_methods + func_doc = inspect.cleandoc( + self._deployer_injected_methods[element.__name__[5:]][ + "from_deployment" + ][1] + or "" + ) + docs = split_docs( + func_doc, + [ + ("func_doc", StartEnd(0, 0)), + ( + "param_doc", + param_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ( + "return_doc", + return_section_header.search(func_doc) + or StartEnd(len(func_doc), len(func_doc)), + ), + ], + ) + + parameters, _ = parse_params_from_doc(docs["param_doc"]) + return_type = self._deployer_injected_methods[ + element.__name__[5:] + ]["from_deployment"][0] + + buff.write( + self._generate_function_stub( + key, + element, + sign=[ + inspect.Signature( + parameters=[ + inspect.Parameter( + "cls", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + + parameters, + return_annotation=return_type, + ) + ], + indentation=TAB, + doc=docs["func_doc"] + + "\n\nParameters\n----------\n" + + docs["param_doc"] + + "\n\nReturns\n-------\n" + + docs["return_doc"], + deco=func_deco, + ) + ) + else: + if ( + issubclass(clazz, DeployedFlow) + and clazz.TYPE is not None + and key == "from_deployment" + ): + clazz_type = clazz.TYPE.replace("-", "_") + # Record docstring for this function + self._deployer_injected_methods.setdefault(clazz_type, {})[ + "from_deployment" + ] = ( + self._current_module_name + "." + name, + element.__doc__, + ) + buff.write( + self._generate_function_stub( + key, + element, + indentation=TAB, + deco=func_deco, + ) + ) + elif isinstance(element, property): if element.fget: buff.write( @@ -455,20 +888,17 @@ def _generate_class_stub(self, name: str, clazz: type) -> str: ) ) - # Special handling for the current module - if ( - self._current_module_name == METAFLOW_CURRENT_MODULE_NAME - and name == "Current" - ): + # Special handling of classes that have injected methods + if clazz == Current: # Multiple decorators can add the same object (trigger and trigger_on_finish) # as examples so we sort it out. 
resulting_dict = ( dict() ) # type Dict[str, List[inspect.Signature, str, List[str]]] - for project_name, addl_current in self._addl_current.items(): + for deco_name, addl_current in self._addl_current.items(): for name, (sign, doc) in addl_current.items(): r = resulting_dict.setdefault(name, [sign, doc, []]) - r[2].append("@%s" % project_name) + r[2].append("@%s" % deco_name) for name, (sign, doc, decos) in resulting_dict.items(): buff.write( self._generate_function_stub( @@ -481,7 +911,8 @@ def _generate_class_stub(self, name: str, clazz: type) -> str: deco="@property", ) ) - if init_func is None and annotation_dict: + + if not skip_init and init_func is None and annotation_dict: buff.write( self._generate_function_stub( "__init__", @@ -527,121 +958,31 @@ def _extract_signature_from_decorator( self._typevars["StepFlag"] = StepFlag raw_doc = inspect.cleandoc(raw_doc) - has_parameters = param_section_header.search(raw_doc) - has_add_to_current = add_to_current_header.search(raw_doc) - - if has_parameters and has_add_to_current: - doc = raw_doc[has_parameters.end() : has_add_to_current.start()] - add_to_current_doc = raw_doc[has_add_to_current.end() :] - raw_doc = raw_doc[: has_add_to_current.start()] - elif has_parameters: - doc = raw_doc[has_parameters.end() :] - add_to_current_doc = None - elif has_add_to_current: - add_to_current_doc = raw_doc[has_add_to_current.end() :] - raw_doc = raw_doc[: has_add_to_current.start()] - doc = "" - else: - doc = "" - add_to_current_doc = None - parameters = [] - no_arg_version = True - for line in doc.splitlines(): - if non_indented_line.match(line): - match = param_name_type.match(line) - arg_name = type_name = is_optional = default = None - default_set = False - if match is not None: - arg_name = match.group("name") - type_name = match.group("type") - if type_name is not None: - type_detail = type_annotations.match(type_name) - if type_detail is not None: - type_name = type_detail.group("type") - is_optional = type_detail.group("optional") is not None - default = type_detail.group("default") - if default: - default_set = True - try: - default = eval(default) - except: - pass - try: - type_name = eval(type_name) - except: - pass - parameters.append( - inspect.Parameter( - name=arg_name, - kind=inspect.Parameter.KEYWORD_ONLY, - default=( - default - if default_set - else None if is_optional else inspect.Parameter.empty - ), - annotation=( - Optional[type_name] if is_optional else type_name - ), - ) - ) - if not default_set: - # If we don't have a default set for any parameter, we can't - # have a no-arg version since the decorator would be incomplete - no_arg_version = False - if add_to_current_doc: - current_property = None - current_return_type = None - current_property_indent = None - current_doc = [] - add_to_current = dict() # type: Dict[str, Tuple[inspect.Signature, str]] - - def _add(): - if current_property: - add_to_current[current_property] = ( - inspect.Signature( - [ - inspect.Parameter( - "self", inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - ], - return_annotation=current_return_type, - ), - "\n".join(current_doc), - ) - - for line in add_to_current_doc.splitlines(): - # Parse stanzas that look like the following: - # -> type - # indented doc string - if current_property_indent is not None and ( - line.startswith(current_property_indent + " ") or line.strip() == "" - ): - offset = len(current_property_indent) - if line.lstrip().startswith("@@ "): - line = line.replace("@@ ", "") - current_doc.append(line[offset:].rstrip()) - else: - if 
line.strip() == 0: - continue - if current_property: - # Ends a property stanza - _add() - # Now start a new one - line = line.rstrip() - current_property_indent = line[: len(line) - len(line.lstrip())] - # This is a line so we split it using "->" - current_property, current_return_type = line.split("->") - current_property = current_property.strip() - current_return_type = current_return_type.strip() - current_doc = [] - _add() - - self._addl_current[name] = add_to_current + section_boundaries = [ + ("func_doc", StartEnd(0, 0)), + ( + "param_doc", + param_section_header.search(raw_doc) + or StartEnd(len(raw_doc), len(raw_doc)), + ), + ( + "add_to_current_doc", + add_to_current_header.search(raw_doc) + or StartEnd(len(raw_doc), len(raw_doc)), + ), + ] + + docs = split_docs(raw_doc, section_boundaries) + + parameters, no_arg_version = parse_params_from_doc(docs["param_doc"]) + + if docs["add_to_current_doc"]: + self._addl_current[name] = parse_add_to_docs(docs["add_to_current_doc"]) result = [] if no_arg_version: if is_flow_decorator: - if has_parameters: + if docs["param_doc"]: result.append( ( inspect.Signature( @@ -670,7 +1011,7 @@ def _add(): ), ) else: - if has_parameters: + if docs["param_doc"]: result.append( ( inspect.Signature( @@ -792,8 +1133,8 @@ def _add(): result = result[1:] # Add doc to first and last overloads. Jedi uses the last one and pycharm # the first one. Go figure. - result[0] = (result[0][0], raw_doc) - result[-1] = (result[-1][0], raw_doc) + result[0] = (result[0][0], docs["func_doc"]) + result[-1] = (result[-1][0], docs["func_doc"]) return result def _generate_function_stub( @@ -805,11 +1146,12 @@ def _generate_function_stub( doc: Optional[str] = None, deco: Optional[str] = None, ) -> str: + debug.stubgen_exec("Generating function stub for %s" % name) + def exploit_default(default_value: Any) -> Optional[str]: - if ( - default_value != inspect.Parameter.empty - and type(default_value).__module__ == "builtins" - ): + if default_value == inspect.Parameter.empty: + return None + if type(default_value).__module__ == "builtins": if isinstance(default_value, list): return ( "[" @@ -839,22 +1181,23 @@ def exploit_default(default_value: Any) -> Optional[str]: ) + "}" ) - elif str(default_value).startswith("<"): - if default_value.__module__ == "builtins": - return default_value.__name__ - else: - self._typing_imports.add(default_value.__module__) - return ".".join( - [default_value.__module__, default_value.__name__] - ) + elif isinstance(default_value, str): + return "'" + default_value + "'" else: - return ( - str(default_value) - if not isinstance(default_value, str) - else '"' + default_value + '"' - ) + return self._get_element_name_with_module(default_value) + + elif str(default_value).startswith("<"): + if default_value.__module__ == "builtins": + return default_value.__name__ + else: + self._typing_imports.add(default_value.__module__) + return ".".join([default_value.__module__, default_value.__name__]) else: - return None + return ( + str(default_value) + if not isinstance(default_value, str) + else '"' + default_value + '"' + ) buff = StringIO() if sign is None and func is None: @@ -870,6 +1213,10 @@ def exploit_default(default_value: Any) -> Optional[str]: # value return "" doc = doc or func.__doc__ + if doc == "STUBGEN_IGNORE": + # Ignore methods that have STUBGEN_IGNORE. 
Used to ignore certain + # methods for the Deployer + return "" indentation = indentation or "" # Deal with overload annotations -- the last one will be non overloaded and @@ -883,6 +1230,9 @@ def exploit_default(default_value: Any) -> Optional[str]: buff.write("\n") if do_overload and count < len(sign) - 1: + # According to mypy, we should have this on all variants but + # some IDEs seem to prefer if there is one non-overloaded + # This also changes our checks so if changing, modify tests buff.write(indentation + "@typing.overload\n") if deco: buff.write(indentation + deco + "\n") @@ -890,6 +1240,7 @@ def exploit_default(default_value: Any) -> Optional[str]: kw_only_param = False for i, (par_name, parameter) in enumerate(my_sign.parameters.items()): annotation = self._exploit_annotation(parameter.annotation) + default = exploit_default(parameter.default) if kw_only_param and parameter.kind != inspect.Parameter.KEYWORD_ONLY: @@ -922,7 +1273,7 @@ def exploit_default(default_value: Any) -> Optional[str]: if (count == 0 or count == len(sign) - 1) and doc is not None: buff.write('%s%s"""\n' % (indentation, TAB)) - my_doc = cast(str, deindent_docstring(doc)) + my_doc = inspect.cleandoc(doc) init_blank = True for line in my_doc.split("\n"): if init_blank and len(line.strip()) == 0: @@ -941,6 +1292,7 @@ def _generate_generic_stub(self, element_name: str, element: Any) -> str: def _generate_stubs(self): for name, attr in self._current_objects.items(): self._current_parent_module = inspect.getmodule(attr) + self._current_name = name if inspect.isclass(attr): self._stubs.append(self._generate_class_stub(name, attr)) elif inspect.isfunction(attr): @@ -1023,6 +1375,29 @@ def _generate_stubs(self): elif not inspect.ismodule(attr): self._stubs.append(self._generate_generic_stub(name, attr)) + def _write_header(self, f, width): + title_line = "Auto-generated Metaflow stub file" + title_white_space = (width - len(title_line)) / 2 + title_line = "#%s%s%s#\n" % ( + " " * math.floor(title_white_space), + title_line, + " " * math.ceil(title_white_space), + ) + f.write( + "#" * (width + 2) + + "\n" + + title_line + + "# MF version: %s%s#\n" + % (self._mf_version, " " * (width - 13 - len(self._mf_version))) + + "# Generated on %s%s#\n" + % ( + datetime.fromtimestamp(time.time()).isoformat(), + " " * (width - 14 - 26), + ) + + "#" * (width + 2) + + "\n\n" + ) + def write_out(self): out_dir = self._output_dir os.makedirs(out_dir, exist_ok=True) @@ -1036,66 +1411,75 @@ def write_out(self): "%s %s" % (self._mf_version, datetime.fromtimestamp(time.time()).isoformat()) ) - while len(self._pending_modules) != 0: - module_name = self._pending_modules.pop(0) + post_process_modules = [] + is_post_processing = False + while len(self._pending_modules) != 0 or len(post_process_modules) != 0: + if is_post_processing or len(self._pending_modules) == 0: + is_post_processing = True + module_alias, module_name = post_process_modules.pop(0) + else: + module_alias, module_name = self._pending_modules.pop(0) # Skip vendored stuff - if module_name.startswith("metaflow._vendor"): + if module_alias.startswith("metaflow._vendor") or module_name.startswith( + "metaflow._vendor" + ): continue - # We delay current module + # We delay current module and deployer module to the end since they + # depend on info we gather elsewhere if ( - module_name == METAFLOW_CURRENT_MODULE_NAME - and len(set(self._pending_modules)) > 1 + module_alias + in ( + METAFLOW_CURRENT_MODULE_NAME, + METAFLOW_DEPLOYER_MODULE_NAME, + ) + and 
len(self._pending_modules) != 0 ): - self._pending_modules.append(module_name) + post_process_modules.append((module_alias, module_name)) continue - if module_name in self._done_modules: + if module_alias in self._done_modules: continue - self._done_modules.add(module_name) + self._done_modules.add(module_alias) # If not, we process the module self._reset() - self._get_module(module_name) + self._get_module(module_alias, module_name) + if module_name == "metaflow" and not is_post_processing: + # We will want to regenerate this at the end to take into account + # any changes to the Deployer + post_process_modules.append((module_name, module_name)) + self._done_modules.remove(module_name) + continue self._generate_stubs() if hasattr(self._current_module, "__path__"): # This is a package (so a directory) and we are dealing with # a __init__.pyi type of case - dir_path = os.path.join( - self._output_dir, *self._current_module.__name__.split(".")[1:] - ) + dir_path = os.path.join(self._output_dir, *module_alias.split(".")[1:]) else: # This is NOT a package so the original source file is not a __init__.py dir_path = os.path.join( - self._output_dir, *self._current_module.__name__.split(".")[1:-1] + self._output_dir, *module_alias.split(".")[1:-1] ) out_file = os.path.join( dir_path, os.path.basename(self._current_module.__file__) + "i" ) + width = 100 + os.makedirs(os.path.dirname(out_file), exist_ok=True) + # We want to make sure we always have a __init__.pyi in the directories + # we are creating + parts = dir_path.split(os.sep)[len(self._output_dir.split(os.sep)) :] + for i in range(1, len(parts) + 1): + init_file_path = os.path.join( + self._output_dir, *parts[:i], "__init__.pyi" + ) + if not os.path.exists(init_file_path): + with open(init_file_path, mode="w", encoding="utf-8") as f: + self._write_header(f, width) - width = 80 - title_line = "Auto-generated Metaflow stub file" - title_white_space = (width - len(title_line)) / 2 - title_line = "#%s%s%s#\n" % ( - " " * math.floor(title_white_space), - title_line, - " " * math.ceil(title_white_space), - ) with open(out_file, mode="w", encoding="utf-8") as f: - f.write( - "#" * (width + 2) - + "\n" - + title_line - + "# MF version: %s%s#\n" - % (self._mf_version, " " * (width - 13 - len(self._mf_version))) - + "# Generated on %s%s#\n" - % ( - datetime.fromtimestamp(time.time()).isoformat(), - " " * (width - 14 - 26), - ) - + "#" * (width + 2) - + "\n\n" - ) + self._write_header(f, width) + f.write("from __future__ import annotations\n\n") imported_typing = False for module in self._imports: @@ -1123,8 +1507,14 @@ def write_out(self): "%s = %s\n" % (type_name, new_type_to_str(type_var)) ) f.write("\n") + for import_line in self._current_references: + f.write(import_line + "\n") + f.write("\n") for stub in self._stubs: f.write(stub + "\n") + if is_post_processing: + # Don't consider any pending modules if we are post processing + self._pending_modules.clear() if __name__ == "__main__": diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index f1b8185d52e..1ea06167498 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -10,7 +10,7 @@ from .. 
import metaflow_config from ..exception import MetaflowInternalError -from ..metadata import DataArtifact, MetaDatum +from ..metadata_provider import DataArtifact, MetaDatum from ..parameters import Parameter from ..util import Path, is_stringish, to_fileobj diff --git a/metaflow/extension_support/plugins.py b/metaflow/extension_support/plugins.py index 9dec577600c..9472202c510 100644 --- a/metaflow/extension_support/plugins.py +++ b/metaflow/extension_support/plugins.py @@ -178,6 +178,7 @@ def resolve_plugins(category): "environment": lambda x: x.TYPE, "metadata_provider": lambda x: x.TYPE, "datastore": lambda x: x.TYPE, + "dataclient": lambda x: x.TYPE, "secrets_provider": lambda x: x.TYPE, "gcp_client_provider": lambda x: x.name, "deployer_impl_provider": lambda x: x.TYPE, diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index e21b9597009..0c7ffd1f128 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -64,7 +64,7 @@ def __getitem__(self, item): return item or 0 # item is None for the control task, but it is also split 0 -class _FlowSpecMeta(type): +class FlowSpecMeta(type): def __new__(cls, name, bases, dct): f = super().__new__(cls, name, bases, dct) # This makes sure to give _flow_decorators to each @@ -75,7 +75,7 @@ def __new__(cls, name, bases, dct): return f -class FlowSpec(metaclass=_FlowSpecMeta): +class FlowSpec(metaclass=FlowSpecMeta): """ Main class from which all Flows should inherit. diff --git a/metaflow/includefile.py b/metaflow/includefile.py index c035252c706..4bc16172863 100644 --- a/metaflow/includefile.py +++ b/metaflow/includefile.py @@ -1,6 +1,7 @@ from collections import namedtuple import gzip +import importlib import io import json import os @@ -17,6 +18,8 @@ Parameter, ParameterContext, ) + +from .plugins import DATACLIENTS from .util import get_username import functools @@ -47,16 +50,7 @@ # From here on out, this is the IncludeFile implementation. -from metaflow.plugins.datatools import Local, S3 -from metaflow.plugins.azure.includefile_support import Azure -from metaflow.plugins.gcp.includefile_support import GS - -DATACLIENTS = { - "local": Local, - "s3": S3, - "azure": Azure, - "gs": GS, -} +_dict_dataclients = {d.TYPE: d for d in DATACLIENTS} class IncludedFile(object): @@ -167,7 +161,7 @@ def convert(self, value, param, ctx): "IncludeFile using a direct reference to a file in cloud storage is no " "longer supported. 
Contact the Metaflow team if you need this supported" ) - # if DATACLIENTS.get(path[:prefix_pos]) is None: + # if _dict_dataclients.get(path[:prefix_pos]) is None: # self.fail( # "IncludeFile: no handler for external file of type '%s' " # "(given path is '%s')" % (path[:prefix_pos], path) @@ -187,7 +181,7 @@ def convert(self, value, param, ctx): pass except OSError: self.fail("IncludeFile: could not open file '%s' for reading" % path) - handler = DATACLIENTS.get(ctx.ds_type) + handler = _dict_dataclients.get(ctx.ds_type) if handler is None: self.fail( "IncludeFile: no data-client for datastore of type '%s'" @@ -213,7 +207,7 @@ def _delayed_eval_func(ctx=lambda_ctx, return_str=False): ctx.path, ctx.is_text, ctx.encoding, - DATACLIENTS[ctx.handler_type], + _dict_dataclients[ctx.handler_type], ctx.echo, ) ) @@ -425,7 +419,7 @@ def _get_handler(url): if prefix_pos < 0: raise MetaflowException("Malformed URL: '%s'" % url) prefix = url[:prefix_pos] - handler = DATACLIENTS.get(prefix) + handler = _dict_dataclients.get(prefix) if handler is None: raise MetaflowException("Could not find data client for '%s'" % prefix) return handler diff --git a/metaflow/metadata/__init__.py b/metaflow/metadata_provider/__init__.py similarity index 100% rename from metaflow/metadata/__init__.py rename to metaflow/metadata_provider/__init__.py diff --git a/metaflow/metadata/heartbeat.py b/metaflow/metadata_provider/heartbeat.py similarity index 100% rename from metaflow/metadata/heartbeat.py rename to metaflow/metadata_provider/heartbeat.py diff --git a/metaflow/metadata/metadata.py b/metaflow/metadata_provider/metadata.py similarity index 100% rename from metaflow/metadata/metadata.py rename to metaflow/metadata_provider/metadata.py diff --git a/metaflow/metadata/util.py b/metaflow/metadata_provider/util.py similarity index 100% rename from metaflow/metadata/util.py rename to metaflow/metadata_provider/util.py diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 1d820ca0620..afc1208ae07 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -43,6 +43,10 @@ DEFAULT_SECRETS_BACKEND_TYPE = from_conf("DEFAULT_SECRETS_BACKEND_TYPE") DEFAULT_SECRETS_ROLE = from_conf("DEFAULT_SECRETS_ROLE") +DEFAULT_FROM_DEPLOYMENT_IMPL = from_conf( + "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" +) + ### # User configuration ### diff --git a/metaflow/metaflow_current.py b/metaflow/metaflow_current.py index 0ca590c24f6..8443c1d75ab 100644 --- a/metaflow/metaflow_current.py +++ b/metaflow/metaflow_current.py @@ -30,7 +30,7 @@ def _raise(ex): raise ex self.__class__.graph = property( - fget=lambda _: _raise(RuntimeError("Graph is not available")) + fget=lambda self: _raise(RuntimeError("Graph is not available")) ) def _set_env( diff --git a/metaflow/parameters.py b/metaflow/parameters.py index eca634e7f6a..fe0dabbda3f 100644 --- a/metaflow/parameters.py +++ b/metaflow/parameters.py @@ -438,3 +438,6 @@ def wrapper(cmd): return cmd return wrapper + + +JSONType = JSONTypeClass() diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index a1a7f593a5b..f2a4d7cdb43 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -73,8 +73,8 @@ # Add metadata providers here METADATA_PROVIDERS_DESC = [ - ("service", ".metadata.service.ServiceMetadataProvider"), - ("local", ".metadata.local.LocalMetadataProvider"), + ("service", ".metadata_providers.service.ServiceMetadataProvider"), + ("local", ".metadata_providers.local.LocalMetadataProvider"), ] # Add datastore 
here @@ -85,13 +85,21 @@ ("gs", ".datastores.gs_storage.GSStorage"), ] +# Dataclients are used for IncludeFile +DATACLIENTS_DESC = [ + ("local", ".datatools.Local"), + ("s3", ".datatools.S3"), + ("azure", ".azure.includefile_support.Azure"), + ("gs", ".gcp.includefile_support.GS"), +] + # Add non monitoring/logging sidecars here SIDECARS_DESC = [ ( "save_logs_periodically", "..mflog.save_logs_periodically.SaveLogsPeriodicallySidecar", ), - ("heartbeat", "metaflow.metadata.heartbeat.MetadataHeartBeat"), + ("heartbeat", "metaflow.metadata_provider.heartbeat.MetadataHeartBeat"), ] # Add logging sidecars here @@ -161,6 +169,7 @@ def get_plugin_cli(): ENVIRONMENTS = resolve_plugins("environment") METADATA_PROVIDERS = resolve_plugins("metadata_provider") DATASTORES = resolve_plugins("datastore") +DATACLIENTS = resolve_plugins("dataclient") SIDECARS = resolve_plugins("sidecar") LOGGING_SIDECARS = resolve_plugins("logging_sidecar") MONITOR_SIDECARS = resolve_plugins("monitor_sidecar") diff --git a/metaflow/plugins/airflow/airflow_decorator.py b/metaflow/plugins/airflow/airflow_decorator.py index 11bdecaaa8b..01400867de8 100644 --- a/metaflow/plugins/airflow/airflow_decorator.py +++ b/metaflow/plugins/airflow/airflow_decorator.py @@ -1,7 +1,7 @@ import json import os from metaflow.decorators import StepDecorator -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum from .airflow_utils import ( TASK_ID_XCOM_KEY, diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index 65335e9b6f7..c3676edc0c4 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -6,7 +6,7 @@ from metaflow import current from metaflow.decorators import StepDecorator from metaflow.events import Trigger -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum from metaflow.metaflow_config import ARGO_EVENTS_WEBHOOK_URL from metaflow.graph import DAGNode, FlowGraph from metaflow.flowspec import FlowSpec diff --git a/metaflow/plugins/argo/argo_workflows_deployer.py b/metaflow/plugins/argo/argo_workflows_deployer.py index 1a3056bc28d..0c9da919633 100644 --- a/metaflow/plugins/argo/argo_workflows_deployer.py +++ b/metaflow/plugins/argo/argo_workflows_deployer.py @@ -1,292 +1,106 @@ -import sys -import tempfile -from typing import Optional, ClassVar +from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type -from metaflow.plugins.argo.argo_workflows import ArgoWorkflows -from metaflow.runner.deployer import ( - DeployerImpl, - DeployedFlow, - TriggeredRun, - get_lower_level_group, - handle_timeout, -) +from metaflow.runner.deployer_impl import DeployerImpl - -def suspend(instance: TriggeredRun, **kwargs): - """ - Suspend the running workflow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the suspend command. - - Returns - ------- - bool - True if the command was successful, False otherwise. 
- """ - _, run_id = instance.pathspec.split("/") - - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).suspend(run_id=run_id, **kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def unsuspend(instance: TriggeredRun, **kwargs): - """ - Unsuspend the suspended workflow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the unsuspend command. - - Returns - ------- - bool - True if the command was successful, False otherwise. - """ - _, run_id = instance.pathspec.split("/") - - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).unsuspend(run_id=run_id, **kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def terminate(instance: TriggeredRun, **kwargs): - """ - Terminate the running workflow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the terminate command. - - Returns - ------- - bool - True if the command was successful, False otherwise. - """ - _, run_id = instance.pathspec.split("/") - - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).terminate(run_id=run_id, **kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def status(instance: TriggeredRun): - """ - Get the status of the triggered run. - - Returns - ------- - str, optional - The status of the workflow considering the run object, or None if the status could not be retrieved. - """ - from metaflow.plugins.argo.argo_workflows_cli import ( - get_status_considering_run_object, - ) - - flow_name, run_id = instance.pathspec.split("/") - name = run_id[5:] - status = ArgoWorkflows.get_workflow_status(flow_name, name) - if status is not None: - return get_status_considering_run_object(status, instance.run) - return None - - -def production_token(instance: DeployedFlow): - """ - Get the production token for the deployed flow. - - Returns - ------- - str, optional - The production token, None if it cannot be retrieved. - """ - try: - _, production_token = ArgoWorkflows.get_existing_deployment( - instance.deployer.name - ) - return production_token - except TypeError: - return None - - -def delete(instance: DeployedFlow, **kwargs): - """ - Delete the deployed flow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the delete command. - - Returns - ------- - bool - True if the command was successful, False otherwise. 
- """ - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).delete(**kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def trigger(instance: DeployedFlow, **kwargs): - """ - Trigger a new run for the deployed flow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the trigger command, `Parameters` in particular - - Returns - ------- - ArgoWorkflowsTriggeredRun - The triggered run instance. - - Raises - ------ - Exception - If there is an error during the trigger process. - """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) - - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - content = handle_timeout( - tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout - ) - - if command_obj.process.returncode == 0: - triggered_run = TriggeredRun(deployer=instance.deployer, content=content) - triggered_run._enrich_object( - { - "status": property(status), - "terminate": terminate, - "suspend": suspend, - "unsuspend": unsuspend, - } - ) - return triggered_run - - raise Exception( - "Error triggering %s on %s for %s" - % (instance.deployer.name, instance.deployer.TYPE, instance.deployer.flow_file) - ) +if TYPE_CHECKING: + import metaflow.plugins.argo.argo_workflows_deployer_objects class ArgoWorkflowsDeployer(DeployerImpl): """ Deployer implementation for Argo Workflows. - Attributes + Parameters ---------- - TYPE : ClassVar[Optional[str]] - The type of the deployer, which is "argo-workflows". + name : str, optional, default None + Argo workflow name. The flow name is used instead if this option is not specified. """ TYPE: ClassVar[Optional[str]] = "argo-workflows" - def __init__(self, deployer_kwargs, **kwargs): + def __init__(self, deployer_kwargs: Dict[str, str], **kwargs): """ Initialize the ArgoWorkflowsDeployer. Parameters ---------- - deployer_kwargs : dict + deployer_kwargs : Dict[str, str] The deployer-specific keyword arguments. **kwargs : Any Additional arguments to pass to the superclass constructor. 
""" - self.deployer_kwargs = deployer_kwargs + self._deployer_kwargs = deployer_kwargs super().__init__(**kwargs) - def _enrich_deployed_flow(self, deployed_flow: DeployedFlow): + @property + def deployer_kwargs(self) -> Dict[str, Any]: + return self._deployer_kwargs + + @staticmethod + def deployed_flow_type() -> ( + Type[ + "metaflow.plugins.argo.argo_workflows_deployer_objects.ArgoWorkflowsDeployedFlow" + ] + ): + from .argo_workflows_deployer_objects import ArgoWorkflowsDeployedFlow + + return ArgoWorkflowsDeployedFlow + + def create( + self, **kwargs + ) -> "metaflow.plugins.argo.argo_workflows_deployer_objects.ArgoWorkflowsDeployedFlow": """ - Enrich the DeployedFlow object with additional properties and methods. + Create a new ArgoWorkflow deployment. Parameters ---------- - deployed_flow : DeployedFlow - The deployed flow object to enrich. + authorize : str, optional, default None + Authorize using this production token. Required when re-deploying an existing flow + for the first time. The token is cached in METAFLOW_HOME. + generate_new_token : bool, optional, default False + Generate a new production token for this flow. Moves the production flow to a new namespace. + given_token : str, optional, default None + Use the given production token for this flow. Moves the production flow to the given namespace. + tags : List[str], optional, default None + Annotate all objects produced by Argo Workflows runs with these tags. + user_namespace : str, optional, default None + Change the namespace from the default (production token) to the given tag. + only_json : bool, optional, default False + Only print out JSON sent to Argo Workflows without deploying anything. + max_workers : int, optional, default 100 + Maximum number of parallel processes. + workflow_timeout : int, optional, default None + Workflow timeout in seconds. + workflow_priority : int, optional, default None + Workflow priority as an integer. Higher priority workflows are processed first + if Argo Workflows controller is configured to process limited parallel workflows. + auto_emit_argo_events : bool, optional, default True + Auto emits Argo Events when the run completes successfully. + notify_on_error : bool, optional, default False + Notify if the workflow fails. + notify_on_success : bool, optional, default False + Notify if the workflow succeeds. + notify_slack_webhook_url : str, optional, default '' + Slack incoming webhook url for workflow success/failure notifications. + notify_pager_duty_integration_key : str, optional, default '' + PagerDuty Events API V2 Integration key for workflow success/failure notifications. + enable_heartbeat_daemon : bool, optional, default False + Use a daemon container to broadcast heartbeats. + deployer_attribute_file : str, optional, default None + Write the workflow name to the specified file. Used internally for Metaflow's Deployer API. + enable_error_msg_capture : bool, optional, default True + Capture stack trace of first failed task in exit hook. + + Returns + ------- + ArgoWorkflowsDeployedFlow + The Flow deployed to Argo Workflows. 
""" - deployed_flow._enrich_object( - { - "production_token": property(production_token), - "trigger": trigger, - "delete": delete, - } - ) + + # Prevent circular import + from .argo_workflows_deployer_objects import ArgoWorkflowsDeployedFlow + + return self._create(ArgoWorkflowsDeployedFlow, **kwargs) + + +_addl_stubgen_modules = ["metaflow.plugins.argo.argo_workflows_deployer_objects"] diff --git a/metaflow/plugins/argo/argo_workflows_deployer_objects.py b/metaflow/plugins/argo/argo_workflows_deployer_objects.py new file mode 100644 index 00000000000..6538b70310b --- /dev/null +++ b/metaflow/plugins/argo/argo_workflows_deployer_objects.py @@ -0,0 +1,381 @@ +import sys +import json +import tempfile +from typing import ClassVar, Optional + +from metaflow.client.core import get_metadata +from metaflow.exception import MetaflowException +from metaflow.plugins.argo.argo_client import ArgoClient +from metaflow.metaflow_config import KUBERNETES_NAMESPACE +from metaflow.plugins.argo.argo_workflows import ArgoWorkflows +from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun + +from metaflow.runner.utils import get_lower_level_group, handle_timeout + + +def generate_fake_flow_file_contents( + flow_name: str, param_info: dict, project_name: Optional[str] = None +): + params_code = "" + for _, param_details in param_info.items(): + param_name = param_details["name"] + param_type = param_details["type"] + param_help = param_details["description"] + param_required = param_details["is_required"] + + if param_type == "JSON": + params_code += ( + f" {param_name} = Parameter('{param_name}', " + f"type=JSONType, help='{param_help}', required={param_required})\n" + ) + elif param_type == "FilePath": + is_text = param_details.get("is_text", True) + encoding = param_details.get("encoding", "utf-8") + params_code += ( + f" {param_name} = IncludeFile('{param_name}', " + f"is_text={is_text}, encoding='{encoding}', help='{param_help}', " + f"required={param_required})\n" + ) + else: + params_code += ( + f" {param_name} = Parameter('{param_name}', " + f"type={param_type}, help='{param_help}', required={param_required})\n" + ) + + project_decorator = f"@project(name='{project_name}')\n" if project_name else "" + + contents = f"""\ +from metaflow import FlowSpec, Parameter, IncludeFile, JSONType, step, project +{project_decorator}class {flow_name}(FlowSpec): +{params_code} + @step + def start(self): + self.next(self.end) + @step + def end(self): + pass +if __name__ == '__main__': + {flow_name}() +""" + return contents + + +class ArgoWorkflowsTriggeredRun(TriggeredRun): + """ + A class representing a triggered Argo Workflow execution. + """ + + def suspend(self, **kwargs) -> bool: + """ + Suspend the running workflow. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the suspension with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. 
+ """ + _, run_id = self.pathspec.split("/") + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).suspend(run_id=run_id, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + def unsuspend(self, **kwargs) -> bool: + """ + Unsuspend the suspended workflow. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the unsuspend with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. + """ + _, run_id = self.pathspec.split("/") + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).unsuspend(run_id=run_id, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + def terminate(self, **kwargs) -> bool: + """ + Terminate the running workflow. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the termination with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. + """ + _, run_id = self.pathspec.split("/") + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).terminate(run_id=run_id, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + @property + def status(self) -> Optional[str]: + """ + Get the status of the triggered run. + + Returns + ------- + str, optional + The status of the workflow considering the run object, or None if + the status could not be retrieved. + """ + from metaflow.plugins.argo.argo_workflows_cli import ( + get_status_considering_run_object, + ) + + flow_name, run_id = self.pathspec.split("/") + name = run_id[5:] + status = ArgoWorkflows.get_workflow_status(flow_name, name) + if status is not None: + return get_status_considering_run_object(status, self.run) + return None + + +class ArgoWorkflowsDeployedFlow(DeployedFlow): + """ + A class representing a deployed Argo Workflow template. + """ + + TYPE: ClassVar[Optional[str]] = "argo-workflows" + + @classmethod + def from_deployment(cls, identifier: str, metadata: Optional[str] = None): + """ + Retrieves a `ArgoWorkflowsDeployedFlow` object from an identifier and optional + metadata. + + Parameters + ---------- + identifier : str + Deployer specific identifier for the workflow to retrieve + metadata : str, optional, default None + Optional deployer specific metadata. + + Returns + ------- + ArgoWorkflowsDeployedFlow + A `ArgoWorkflowsDeployedFlow` object representing the + deployed flow on argo workflows. 
+ """ + client = ArgoClient(namespace=KUBERNETES_NAMESPACE) + workflow_template = client.get_workflow_template(identifier) + + if workflow_template is None: + raise MetaflowException("No deployed flow found for: %s" % identifier) + + metadata_annotations = workflow_template.get("metadata", {}).get( + "annotations", {} + ) + + flow_name = metadata_annotations.get("metaflow/flow_name", "") + username = metadata_annotations.get("metaflow/owner", "") + parameters = json.loads(metadata_annotations.get("metaflow/parameters", {})) + + # these two only exist if @project decorator is used.. + branch_name = metadata_annotations.get("metaflow/branch_name", None) + project_name = metadata_annotations.get("metaflow/project_name", None) + + project_kwargs = {} + if branch_name is not None: + if branch_name.startswith("prod."): + project_kwargs["production"] = True + project_kwargs["branch"] = branch_name[len("prod.") :] + elif branch_name.startswith("test."): + project_kwargs["branch"] = branch_name[len("test.") :] + elif branch_name == "prod": + project_kwargs["production"] = True + + fake_flow_file_contents = generate_fake_flow_file_contents( + flow_name=flow_name, param_info=parameters, project_name=project_name + ) + + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as fake_flow_file: + with open(fake_flow_file.name, "w") as fp: + fp.write(fake_flow_file_contents) + + if branch_name is not None: + d = Deployer( + fake_flow_file.name, + env={"METAFLOW_USER": username}, + **project_kwargs, + ).argo_workflows() + else: + d = Deployer( + fake_flow_file.name, env={"METAFLOW_USER": username} + ).argo_workflows(name=identifier) + + d.name = identifier + d.flow_name = flow_name + if metadata is None: + d.metadata = get_metadata() + else: + d.metadata = metadata + + return cls(deployer=d) + + @property + def production_token(self) -> Optional[str]: + """ + Get the production token for the deployed flow. + + Returns + ------- + str, optional + The production token, None if it cannot be retrieved. + """ + try: + _, production_token = ArgoWorkflows.get_existing_deployment( + self.deployer.name + ) + return production_token + except TypeError: + return None + + def delete(self, **kwargs) -> bool: + """ + Delete the deployed workflow template. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the deletion with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. + """ + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).delete(**kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun: + """ + Trigger a new run for the deployed flow. + + Parameters + ---------- + **kwargs : Any + Additional arguments to pass to the trigger command, + `Parameters` in particular. + + Returns + ------- + ArgoWorkflowsTriggeredRun + The triggered run instance. + + Raises + ------ + Exception + If there is an error during the trigger process. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + tfp_runner_attribute = tempfile.NamedTemporaryFile( + dir=temp_dir, delete=False + ) + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + content = handle_timeout( + tfp_runner_attribute, command_obj, self.deployer.file_read_timeout + ) + + if command_obj.process.returncode == 0: + return ArgoWorkflowsTriggeredRun( + deployer=self.deployer, content=content + ) + + raise Exception( + "Error triggering %s on %s for %s" + % ( + self.deployer.name, + self.deployer.TYPE, + self.deployer.flow_file, + ) + ) diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index 7f120949006..d99c04e7be1 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -7,7 +7,7 @@ from metaflow import util from metaflow import R from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY -from metaflow.metadata.util import sync_local_metadata_from_datastore +from metaflow.metadata_provider.util import sync_local_metadata_from_datastore from metaflow.metaflow_config import DATASTORE_LOCAL_DIR from metaflow.mflog import TASK_LOG_SOURCE from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 71a36ee2e16..52291d49cd5 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -10,8 +10,8 @@ from metaflow.decorators import StepDecorator from metaflow.plugins.resources_decorator import ResourcesDecorator from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task -from metaflow.metadata import MetaDatum -from metaflow.metadata.util import sync_local_metadata_to_datastore +from metaflow.metadata_provider import MetaDatum +from metaflow.metadata_provider.util import sync_local_metadata_to_datastore from metaflow.metaflow_config import ( ECS_S3_ACCESS_IAM_ROLE, BATCH_JOB_QUEUE, diff --git a/metaflow/plugins/aws/step_functions/step_functions_decorator.py b/metaflow/plugins/aws/step_functions/step_functions_decorator.py index 89a7de79857..9754a67a443 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_decorator.py +++ b/metaflow/plugins/aws/step_functions/step_functions_decorator.py @@ -3,7 +3,7 @@ import time from metaflow.decorators import StepDecorator -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum from .dynamo_db_client import DynamoDbClient diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer.py b/metaflow/plugins/aws/step_functions/step_functions_deployer.py index d9186e771cb..6e2f1d72151 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_deployer.py +++ b/metaflow/plugins/aws/step_functions/step_functions_deployer.py @@ -1,253 +1,94 @@ -import sys -import json -import tempfile -from typing import Optional, ClassVar, List - -from metaflow.plugins.aws.step_functions.step_functions import StepFunctions -from metaflow.runner.deployer import ( - DeployerImpl, - 
DeployedFlow, - TriggeredRun, - get_lower_level_group, - handle_timeout, -) - - -def terminate(instance: TriggeredRun, **kwargs): - """ - Terminate the running workflow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the terminate command. - - Returns - ------- - bool - True if the command was successful, False otherwise. - """ - _, run_id = instance.pathspec.split("/") +from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).terminate(run_id=run_id, **kwargs) +from metaflow.runner.deployer_impl import DeployerImpl - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def production_token(instance: DeployedFlow): - """ - Get the production token for the deployed flow. - - Returns - ------- - str, optional - The production token, None if it cannot be retrieved. - """ - try: - _, production_token = StepFunctions.get_existing_deployment( - instance.deployer.name - ) - return production_token - except TypeError: - return None - - -def list_runs(instance: DeployedFlow, states: Optional[List[str]] = None): - """ - List runs of the deployed flow. - - Parameters - ---------- - states : Optional[List[str]], optional - A list of states to filter the runs by. Allowed values are: - RUNNING, SUCCEEDED, FAILED, TIMED_OUT, ABORTED. - If not provided, all states will be considered. - - Returns - ------- - List[TriggeredRun] - A list of TriggeredRun objects representing the runs of the deployed flow. - - Raises - ------ - ValueError - If any of the provided states are invalid or if there are duplicate states. - """ - VALID_STATES = {"RUNNING", "SUCCEEDED", "FAILED", "TIMED_OUT", "ABORTED"} - - if states is None: - states = [] - - unique_states = set(states) - if not unique_states.issubset(VALID_STATES): - invalid_states = unique_states - VALID_STATES - raise ValueError( - f"Invalid states found: {invalid_states}. Valid states are: {VALID_STATES}" - ) - - if len(states) != len(unique_states): - raise ValueError("Duplicate states are not allowed") - - triggered_runs = [] - executions = StepFunctions.list(instance.deployer.name, states) - - for e in executions: - run_id = "sfn-%s" % e["name"] - tr = TriggeredRun( - deployer=instance.deployer, - content=json.dumps( - { - "metadata": instance.deployer.metadata, - "pathspec": "/".join((instance.deployer.flow_name, run_id)), - "name": run_id, - } - ), - ) - tr._enrich_object({"terminate": terminate}) - triggered_runs.append(tr) - - return triggered_runs - - -def delete(instance: DeployedFlow, **kwargs): - """ - Delete the deployed flow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the delete command. - - Returns - ------- - bool - True if the command was successful, False otherwise. 
- """ - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).delete(**kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - return command_obj.process.returncode == 0 - - -def trigger(instance: DeployedFlow, **kwargs): - """ - Trigger a new run for the deployed flow. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to the trigger command, `Parameters` in particular - - Returns - ------- - StepFunctionsTriggeredRun - The triggered run instance. - - Raises - ------ - Exception - If there is an error during the trigger process. - """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) - - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - instance.deployer.api, - instance.deployer.top_level_kwargs, - instance.deployer.TYPE, - instance.deployer.deployer_kwargs, - ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) - - pid = instance.deployer.spm.run_command( - [sys.executable, *command], - env=instance.deployer.env_vars, - cwd=instance.deployer.cwd, - show_output=instance.deployer.show_output, - ) - - command_obj = instance.deployer.spm.get(pid) - content = handle_timeout( - tfp_runner_attribute, command_obj, instance.deployer.file_read_timeout - ) - - if command_obj.process.returncode == 0: - triggered_run = TriggeredRun(deployer=instance.deployer, content=content) - triggered_run._enrich_object({"terminate": terminate}) - return triggered_run - - raise Exception( - "Error triggering %s on %s for %s" - % (instance.deployer.name, instance.deployer.TYPE, instance.deployer.flow_file) - ) +if TYPE_CHECKING: + import metaflow.plugins.aws.step_functions.step_functions_deployer_objects class StepFunctionsDeployer(DeployerImpl): """ Deployer implementation for AWS Step Functions. - Attributes + Parameters ---------- - TYPE : ClassVar[Optional[str]] - The type of the deployer, which is "step-functions". + name : str, optional, default None + State Machine name. The flow name is used instead if this option is not specified. """ TYPE: ClassVar[Optional[str]] = "step-functions" - def __init__(self, deployer_kwargs, **kwargs): + def __init__(self, deployer_kwargs: Dict[str, str], **kwargs): """ Initialize the StepFunctionsDeployer. Parameters ---------- - deployer_kwargs : dict + deployer_kwargs : Dict[str, str] The deployer-specific keyword arguments. **kwargs : Any Additional arguments to pass to the superclass constructor. 
""" - self.deployer_kwargs = deployer_kwargs + self._deployer_kwargs = deployer_kwargs super().__init__(**kwargs) - def _enrich_deployed_flow(self, deployed_flow: DeployedFlow): + @property + def deployer_kwargs(self) -> Dict[str, Any]: + return self._deployer_kwargs + + @staticmethod + def deployed_flow_type() -> ( + Type[ + "metaflow.plugins.aws.step_functions.step_functions_deployer_objects.StepFunctionsDeployedFlow" + ] + ): + from .step_functions_deployer_objects import StepFunctionsDeployedFlow + + return StepFunctionsDeployedFlow + + def create( + self, **kwargs + ) -> "metaflow.plugins.aws.step_functions.step_functions_deployer_objects.StepFunctionsDeployedFlow": """ - Enrich the DeployedFlow object with additional properties and methods. + Create a new AWS Step Functions State Machine deployment. Parameters ---------- - deployed_flow : DeployedFlow - The deployed flow object to enrich. + authorize : str, optional, default None + Authorize using this production token. Required when re-deploying an existing flow + for the first time. The token is cached in METAFLOW_HOME. + generate_new_token : bool, optional, default False + Generate a new production token for this flow. Moves the production flow to a new namespace. + given_token : str, optional, default None + Use the given production token for this flow. Moves the production flow to the given namespace. + tags : List[str], optional, default None + Annotate all objects produced by AWS Step Functions runs with these tags. + user_namespace : str, optional, default None + Change the namespace from the default (production token) to the given tag. + only_json : bool, optional, default False + Only print out JSON sent to AWS Step Functions without deploying anything. + max_workers : int, optional, default 100 + Maximum number of parallel processes. + workflow_timeout : int, optional, default None + Workflow timeout in seconds. + log_execution_history : bool, optional, default False + Log AWS Step Functions execution history to AWS CloudWatch Logs log group. + use_distributed_map : bool, optional, default False + Use AWS Step Functions Distributed Map instead of Inline Map for defining foreach + tasks in Amazon State Language. + deployer_attribute_file : str, optional, default None + Write the workflow name to the specified file. Used internally for Metaflow's Deployer API. + + Returns + ------- + StepFunctionsDeployedFlow + The Flow deployed to AWS Step Functions. 
""" - deployed_flow._enrich_object( - { - "production_token": property(production_token), - "trigger": trigger, - "delete": delete, - "list_runs": list_runs, - } - ) + from .step_functions_deployer_objects import StepFunctionsDeployedFlow + + return self._create(StepFunctionsDeployedFlow, **kwargs) + + +_addl_stubgen_modules = [ + "metaflow.plugins.aws.step_functions.step_functions_deployer_objects" +] diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py new file mode 100644 index 00000000000..9b3528af01b --- /dev/null +++ b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py @@ -0,0 +1,236 @@ +import sys +import json +import tempfile +from typing import ClassVar, Optional, List + +from metaflow.plugins.aws.step_functions.step_functions import StepFunctions +from metaflow.runner.deployer import DeployedFlow, TriggeredRun + +from metaflow.runner.utils import get_lower_level_group, handle_timeout + + +class StepFunctionsTriggeredRun(TriggeredRun): + """ + A class representing a triggered AWS Step Functions state machine execution. + """ + + def terminate(self, **kwargs) -> bool: + """ + Terminate the running state machine execution. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the termination with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. + """ + _, run_id = self.pathspec.split("/") + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).terminate(run_id=run_id, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + +class StepFunctionsDeployedFlow(DeployedFlow): + """ + A class representing a deployed AWS Step Functions state machine. + """ + + TYPE: ClassVar[Optional[str]] = "step-functions" + + @classmethod + def from_deployment(cls, identifier: str, metadata: Optional[str] = None): + """ + This method is not currently implemented for Step Functions. + + Raises + ------ + NotImplementedError + This method is not implemented for Step Functions. + """ + raise NotImplementedError( + "from_deployment is not implemented for StepFunctions" + ) + + @property + def production_token(self: DeployedFlow) -> Optional[str]: + """ + Get the production token for the deployed flow. + + Returns + ------- + str, optional + The production token, None if it cannot be retrieved. + """ + try: + _, production_token = StepFunctions.get_existing_deployment( + self.deployer.name + ) + return production_token + except TypeError: + return None + + def list_runs( + self, states: Optional[List[str]] = None + ) -> List[StepFunctionsTriggeredRun]: + """ + List runs of the deployed flow. + + Parameters + ---------- + states : List[str], optional, default None + A list of states to filter the runs by. Allowed values are: + RUNNING, SUCCEEDED, FAILED, TIMED_OUT, ABORTED. + If not provided, all states will be considered. + + Returns + ------- + List[StepFunctionsTriggeredRun] + A list of TriggeredRun objects representing the runs of the deployed flow. 
+ + Raises + ------ + ValueError + If any of the provided states are invalid or if there are duplicate states. + """ + VALID_STATES = {"RUNNING", "SUCCEEDED", "FAILED", "TIMED_OUT", "ABORTED"} + + if states is None: + states = [] + + unique_states = set(states) + if not unique_states.issubset(VALID_STATES): + invalid_states = unique_states - VALID_STATES + raise ValueError( + f"Invalid states found: {invalid_states}. Valid states are: {VALID_STATES}" + ) + + if len(states) != len(unique_states): + raise ValueError("Duplicate states are not allowed") + + triggered_runs = [] + executions = StepFunctions.list(self.deployer.name, states) + + for e in executions: + run_id = "sfn-%s" % e["name"] + tr = StepFunctionsTriggeredRun( + deployer=self.deployer, + content=json.dumps( + { + "metadata": self.deployer.metadata, + "pathspec": "/".join((self.deployer.flow_name, run_id)), + "name": run_id, + } + ), + ) + triggered_runs.append(tr) + + return triggered_runs + + def delete(self, **kwargs) -> bool: + """ + Delete the deployed state machine. + + Parameters + ---------- + authorize : str, optional, default None + Authorize the deletion with a production token. + + Returns + ------- + bool + True if the command was successful, False otherwise. + """ + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).delete(**kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + return command_obj.process.returncode == 0 + + def trigger(self, **kwargs) -> StepFunctionsTriggeredRun: + """ + Trigger a new run for the deployed flow. + + Parameters + ---------- + **kwargs : Any + Additional arguments to pass to the trigger command, + `Parameters` in particular + + Returns + ------- + StepFunctionsTriggeredRun + The triggered run instance. + + Raises + ------ + Exception + If there is an error during the trigger process. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + tfp_runner_attribute = tempfile.NamedTemporaryFile( + dir=temp_dir, delete=False + ) + + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + + command_obj = self.deployer.spm.get(pid) + content = handle_timeout( + tfp_runner_attribute, command_obj, self.deployer.file_read_timeout + ) + + if command_obj.process.returncode == 0: + return StepFunctionsTriggeredRun( + deployer=self.deployer, content=content + ) + + raise Exception( + "Error triggering %s on %s for %s" + % ( + self.deployer.name, + self.deployer.TYPE, + self.deployer.flow_file, + ) + ) diff --git a/metaflow/plugins/azure/includefile_support.py b/metaflow/plugins/azure/includefile_support.py index e1a36b32c56..1db2fd0ed98 100644 --- a/metaflow/plugins/azure/includefile_support.py +++ b/metaflow/plugins/azure/includefile_support.py @@ -8,6 +8,8 @@ class Azure(object): + TYPE = "azure" + @classmethod def get_root_from_config(cls, echo, create_on_absent=True): from metaflow.metaflow_config import DATATOOLS_AZUREROOT diff --git a/metaflow/plugins/cards/card_cli.py b/metaflow/plugins/cards/card_cli.py index ec4983ce13e..15f15f430fc 100644 --- a/metaflow/plugins/cards/card_cli.py +++ b/metaflow/plugins/cards/card_cli.py @@ -1,5 +1,6 @@ from metaflow.client import Task -from metaflow import JSONType, namespace +from metaflow.parameters import JSONTypeClass +from metaflow import namespace from metaflow.util import resolve_identity from metaflow.exception import ( CommandException, @@ -551,7 +552,7 @@ def _call(): "--options", default=None, show_default=True, - type=JSONType, + type=JSONTypeClass(), help="arguments of the card being created.", ) @click.option( diff --git a/metaflow/plugins/cards/card_modules/components.py b/metaflow/plugins/cards/card_modules/components.py index b5711bcd58a..9a0b5b37994 100644 --- a/metaflow/plugins/cards/card_modules/components.py +++ b/metaflow/plugins/cards/card_modules/components.py @@ -712,15 +712,15 @@ class ProgressBar(UserComponent): Parameters ---------- - max : int + max : int, default 100 The maximum value of the progress bar. - label : str, optional + label : str, optional, default None Optional label for the progress bar. - value : int, optional + value : int, default 0 Optional initial value of the progress bar. - unit : str, optional + unit : str, optional, default None Optional unit for the progress bar. - metadata : str, optional + metadata : str, optional, default None Optional additional information to show on the progress bar. 
""" @@ -731,10 +731,10 @@ class ProgressBar(UserComponent): def __init__( self, max: int = 100, - label: str = None, + label: Optional[str] = None, value: int = 0, - unit: str = None, - metadata: str = None, + unit: Optional[str] = None, + metadata: Optional[str] = None, ): self._label = label self._max = max @@ -742,7 +742,7 @@ def __init__( self._unit = unit self._metadata = metadata - def update(self, new_value: int, metadata: str = None): + def update(self, new_value: int, metadata: Optional[str] = None): self._value = new_value if metadata is not None: self._metadata = metadata diff --git a/metaflow/plugins/datatools/local.py b/metaflow/plugins/datatools/local.py index f326f6e0412..4c12e842a30 100644 --- a/metaflow/plugins/datatools/local.py +++ b/metaflow/plugins/datatools/local.py @@ -81,6 +81,8 @@ class Local(object): In the future, we may want to allow it to be used in a way similar to the S3() client. """ + TYPE = "local" + @staticmethod def _makedirs(path): try: diff --git a/metaflow/plugins/datatools/s3/s3.py b/metaflow/plugins/datatools/s3/s3.py index aeffbabcd11..0f12c199a18 100644 --- a/metaflow/plugins/datatools/s3/s3.py +++ b/metaflow/plugins/datatools/s3/s3.py @@ -504,6 +504,8 @@ class S3(object): If `run` is not specified, use this as the S3 prefix. """ + TYPE = "s3" + @classmethod def get_root_from_config(cls, echo, create_on_absent=True): return DATATOOLS_S3ROOT diff --git a/metaflow/plugins/gcp/includefile_support.py b/metaflow/plugins/gcp/includefile_support.py index 35d9aab96dc..de84aafb9bc 100644 --- a/metaflow/plugins/gcp/includefile_support.py +++ b/metaflow/plugins/gcp/includefile_support.py @@ -8,6 +8,9 @@ class GS(object): + + TYPE = "gs" + @classmethod def get_root_from_config(cls, echo, create_on_absent=True): from metaflow.metaflow_config import DATATOOLS_GSROOT diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 3b5035d1f59..c0d729f161f 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -9,7 +9,7 @@ from metaflow import JSONTypeClass, util from metaflow._vendor import click from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException -from metaflow.metadata.util import sync_local_metadata_from_datastore +from metaflow.metadata_provider.util import sync_local_metadata_from_datastore from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS from metaflow.mflog import TASK_LOG_SOURCE from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 9213b658879..7852685bdcc 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -7,8 +7,8 @@ from metaflow import current from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException -from metaflow.metadata import MetaDatum -from metaflow.metadata.util import sync_local_metadata_to_datastore +from metaflow.metadata_provider import MetaDatum +from metaflow.metadata_provider.util import sync_local_metadata_to_datastore from metaflow.metaflow_config import ( DATASTORE_LOCAL_DIR, KUBERNETES_CONTAINER_IMAGE, @@ -73,8 +73,9 @@ class KubernetesDecorator(StepDecorator): in Metaflow configuration. node_selector: Union[Dict[str,str], str], optional, default None Kubernetes node selector(s) to apply to the pod running the task. 
- Can be passed in as a comma separated string of values e.g. "kubernetes.io/os=linux,kubernetes.io/arch=amd64" - or as a dictionary {"kubernetes.io/os": "linux", "kubernetes.io/arch": "amd64"} + Can be passed in as a comma separated string of values e.g. + 'kubernetes.io/os=linux,kubernetes.io/arch=amd64' or as a dictionary + {'kubernetes.io/os': 'linux', 'kubernetes.io/arch': 'amd64'} namespace : str, default METAFLOW_KUBERNETES_NAMESPACE Kubernetes namespace to use when launching pod in Kubernetes. gpu : int, optional, default None diff --git a/metaflow/plugins/metadata/__init__.py b/metaflow/plugins/metadata_providers/__init__.py similarity index 100% rename from metaflow/plugins/metadata/__init__.py rename to metaflow/plugins/metadata_providers/__init__.py diff --git a/metaflow/plugins/metadata/local.py b/metaflow/plugins/metadata_providers/local.py similarity index 99% rename from metaflow/plugins/metadata/local.py rename to metaflow/plugins/metadata_providers/local.py index 792572219f7..ea7754cac5f 100644 --- a/metaflow/plugins/metadata/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -8,9 +8,9 @@ from collections import namedtuple from metaflow.exception import MetaflowInternalError, MetaflowTaggingError -from metaflow.metadata.metadata import ObjectOrder +from metaflow.metadata_provider.metadata import ObjectOrder from metaflow.metaflow_config import DATASTORE_LOCAL_DIR -from metaflow.metadata import MetadataProvider +from metaflow.metadata_provider import MetadataProvider from metaflow.tagging_util import MAX_USER_TAG_SET_SIZE, validate_tags diff --git a/metaflow/plugins/metadata/service.py b/metaflow/plugins/metadata_providers/service.py similarity index 99% rename from metaflow/plugins/metadata/service.py rename to metaflow/plugins/metadata_providers/service.py index bc6b82c9ac3..2e69026deb1 100644 --- a/metaflow/plugins/metadata/service.py +++ b/metaflow/plugins/metadata_providers/service.py @@ -14,8 +14,8 @@ SERVICE_HEADERS, SERVICE_URL, ) -from metaflow.metadata import MetadataProvider -from metaflow.metadata.heartbeat import HB_URL_KEY +from metaflow.metadata_provider import MetadataProvider +from metaflow.metadata_provider.heartbeat import HB_URL_KEY from metaflow.sidecar import Message, MessageTypes, Sidecar from metaflow.util import version_parse diff --git a/metaflow/plugins/parallel_decorator.py b/metaflow/plugins/parallel_decorator.py index 0a137e9e47a..13492047c2f 100644 --- a/metaflow/plugins/parallel_decorator.py +++ b/metaflow/plugins/parallel_decorator.py @@ -2,7 +2,7 @@ from metaflow.decorators import StepDecorator from metaflow.unbounded_foreach import UBF_CONTROL, CONTROL_TASK_TAG from metaflow.exception import MetaflowException -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum from metaflow.metaflow_current import current, Parallel import os import sys diff --git a/metaflow/plugins/pypi/conda_decorator.py b/metaflow/plugins/pypi/conda_decorator.py index 9ba663276a2..74418ae9f54 100644 --- a/metaflow/plugins/pypi/conda_decorator.py +++ b/metaflow/plugins/pypi/conda_decorator.py @@ -8,7 +8,7 @@ from metaflow.decorators import FlowDecorator, StepDecorator from metaflow.extension_support import EXT_PKG -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum from metaflow.metaflow_environment import InvalidEnvironmentException from metaflow.util import get_metaflow_root diff --git a/metaflow/plugins/test_unbounded_foreach_decorator.py 
b/metaflow/plugins/test_unbounded_foreach_decorator.py index a38d7cb3b6e..5e15cb05501 100644 --- a/metaflow/plugins/test_unbounded_foreach_decorator.py +++ b/metaflow/plugins/test_unbounded_foreach_decorator.py @@ -15,7 +15,7 @@ CONTROL_TASK_TAG, ) from metaflow.util import to_unicode -from metaflow.metadata import MetaDatum +from metaflow.metadata_provider import MetaDatum class InternalTestUnboundedForeachInput(UnboundedForeachInput): diff --git a/metaflow/runner/click_api.py b/metaflow/runner/click_api.py index ed26a8acc08..6bc1fc9b691 100644 --- a/metaflow/runner/click_api.py +++ b/metaflow/runner/click_api.py @@ -193,6 +193,10 @@ def parent(self): def chain(self): return self._chain + @property + def name(self): + return self._API_NAME + @classmethod def from_cli(cls, flow_file: str, cli_collection: Callable) -> Callable: flow_cls = extract_flow_class_from_file(flow_file) diff --git a/metaflow/runner/deployer.py b/metaflow/runner/deployer.py index d2019856209..e49761b3f9e 100644 --- a/metaflow/runner/deployer.py +++ b/metaflow/runner/deployer.py @@ -1,53 +1,50 @@ -import os -import sys import json import time -import importlib -import functools -import tempfile -from typing import Optional, Dict, ClassVar +from typing import ClassVar, Dict, Optional, TYPE_CHECKING from metaflow.exception import MetaflowNotFound -from metaflow.runner.subprocess_manager import SubprocessManager -from metaflow.runner.utils import handle_timeout +from metaflow.metaflow_config import DEFAULT_FROM_DEPLOYMENT_IMPL +if TYPE_CHECKING: + import metaflow + import metaflow.runner.deployer_impl -def get_lower_level_group( - api, top_level_kwargs: Dict, _type: Optional[str], deployer_kwargs: Dict -): - """ - Retrieve a lower-level group from the API based on the type and provided arguments. - Parameters - ---------- - api : MetaflowAPI - Metaflow API instance. - top_level_kwargs : Dict - Top-level keyword arguments to pass to the API. - _type : str - Type of the deployer implementation to target. - deployer_kwargs : Dict - Keyword arguments specific to the deployer. - - Returns - ------- - Any - The lower-level group object retrieved from the API. - - Raises - ------ - ValueError - If the `_type` is None. - """ - if _type is None: - raise ValueError( - "DeployerImpl doesn't have a 'TYPE' to target. Please use a sub-class of DeployerImpl." - ) - return getattr(api(**top_level_kwargs), _type)(**deployer_kwargs) +class DeployerMeta(type): + def __new__(mcs, name, bases, dct): + cls = super().__new__(mcs, name, bases, dct) + + from metaflow.plugins import DEPLOYER_IMPL_PROVIDERS + + def _injected_method(method_name, deployer_class): + def f(self, **deployer_kwargs): + return deployer_class( + deployer_kwargs=deployer_kwargs, + flow_file=self.flow_file, + show_output=self.show_output, + profile=self.profile, + env=self.env, + cwd=self.cwd, + file_read_timeout=self.file_read_timeout, + **self.top_level_kwargs, + ) + + f.__doc__ = provider_class.__doc__ or "" + f.__name__ = method_name + return f + + for provider_class in DEPLOYER_IMPL_PROVIDERS: + # TYPE is the name of the CLI groups i.e. + # `argo-workflows` instead of `argo_workflows` + # The injected method names replace '-' by '_' though. + method_name = provider_class.TYPE.replace("-", "_") + setattr(cls, method_name, _injected_method(method_name, provider_class)) + + return cls -class Deployer(object): +class Deployer(metaclass=DeployerMeta): """ Use the `Deployer` class to configure and access one of the production orchestrators supported by Metaflow. 
@@ -81,7 +78,7 @@ def __init__( env: Optional[Dict] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, - **kwargs + **kwargs, ): self.flow_file = flow_file self.show_output = show_output @@ -91,56 +88,16 @@ def __init__( self.file_read_timeout = file_read_timeout self.top_level_kwargs = kwargs - from metaflow.plugins import DEPLOYER_IMPL_PROVIDERS - - for provider_class in DEPLOYER_IMPL_PROVIDERS: - # TYPE is the name of the CLI groups i.e. - # `argo-workflows` instead of `argo_workflows` - # The injected method names replace '-' by '_' though. - method_name = provider_class.TYPE.replace("-", "_") - setattr(Deployer, method_name, self.__make_function(provider_class)) - - def __make_function(self, deployer_class): - """ - Create a function for the given deployer class. - - Parameters - ---------- - deployer_class : Type[DeployerImpl] - Deployer implementation class. - - Returns - ------- - Callable - Function that initializes and returns an instance of the deployer class. - """ - - def f(self, **deployer_kwargs): - return deployer_class( - deployer_kwargs=deployer_kwargs, - flow_file=self.flow_file, - show_output=self.show_output, - profile=self.profile, - env=self.env, - cwd=self.cwd, - file_read_timeout=self.file_read_timeout, - **self.top_level_kwargs - ) - - return f - class TriggeredRun(object): """ - TriggeredRun class represents a run that has been triggered on a production orchestrator. - - Only when the `start` task starts running, the `run` object corresponding to the run - becomes available. + TriggeredRun class represents a run that has been triggered on a + production orchestrator. """ def __init__( self, - deployer: "DeployerImpl", + deployer: "metaflow.runner.deployer_impl.DeployerImpl", content: str, ): self.deployer = deployer @@ -149,31 +106,18 @@ def __init__( self.pathspec = content_json.get("pathspec") self.name = content_json.get("name") - def _enrich_object(self, env): - """ - Enrich the TriggeredRun object with additional properties and methods. - - Parameters - ---------- - env : dict - Environment dictionary containing properties and methods to add. - """ - for k, v in env.items(): - if isinstance(v, property): - setattr(self.__class__, k, v) - elif callable(v): - setattr(self, k, functools.partial(v, self)) - else: - setattr(self, k, v) - - def wait_for_run(self, timeout=None): + def wait_for_run(self, timeout: Optional[int] = None): """ Wait for the `run` property to become available. + The `run` property becomes available only after the `start` task of the triggered + flow starts running. + Parameters ---------- - timeout : int, optional - Maximum time to wait for the `run` to become available, in seconds. If None, wait indefinitely. + timeout : int, optional, default None + Maximum time to wait for the `run` to become available, in seconds. If + None, wait indefinitely. Raises ------ @@ -194,7 +138,7 @@ def wait_for_run(self, timeout=None): time.sleep(check_interval) @property - def run(self): + def run(self) -> Optional["metaflow.Run"]: """ Retrieve the `Run` object for the triggered run. @@ -214,178 +158,104 @@ def run(self): return None -class DeployedFlow(object): - """ - DeployedFlow class represents a flow that has been deployed. 
+class DeployedFlowMeta(type): + def __new__(mcs, name, bases, dct): + cls = super().__new__(mcs, name, bases, dct) + if not bases: + # Inject methods only in DeployedFlow and not any of its + # subclasses + from metaflow.plugins import DEPLOYER_IMPL_PROVIDERS - Parameters - ---------- - deployer : DeployerImpl - Instance of the deployer implementation. - """ - - def __init__(self, deployer: "DeployerImpl"): - self.deployer = deployer + allowed_providers = dict( + { + provider.TYPE.replace("-", "_"): provider + for provider in DEPLOYER_IMPL_PROVIDERS + } + ) - def _enrich_object(self, env): - """ - Enrich the DeployedFlow object with additional properties and methods. + def _default_injected_method(): + def f( + cls, + identifier: str, + metadata: Optional[str] = None, + impl: str = DEFAULT_FROM_DEPLOYMENT_IMPL.replace("-", "_"), + ) -> "DeployedFlow": + """ + Retrieves a `DeployedFlow` object from an identifier and optional + metadata. The `impl` parameter specifies the deployer implementation + to use (like `argo-workflows`). + + Parameters + ---------- + identifier : str + Deployer specific identifier for the workflow to retrieve + metadata : str, optional, default None + Optional deployer specific metadata. + impl : str, optional, default given by METAFLOW_DEFAULT_FROM_DEPLOYMENT_IMPL + The default implementation to use if not specified + + Returns + ------- + DeployedFlow + A `DeployedFlow` object representing the deployed flow corresponding + to the identifier + """ + if impl in allowed_providers: + return ( + allowed_providers[impl] + .deployed_flow_type() + .from_deployment(identifier, metadata) + ) + else: + raise ValueError( + f"No deployer '{impl}' exists; valid deployers are: " + f"{list(allowed_providers.keys())}" + ) + + f.__name__ = "from_deployment" + return f + + def _per_type_injected_method(method_name, impl): + def f( + cls, + identifier: str, + metadata: Optional[str] = None, + ): + return ( + allowed_providers[impl] + .deployed_flow_type() + .from_deployment(identifier, metadata) + ) + + f.__name__ = method_name + return f + + setattr(cls, "from_deployment", classmethod(_default_injected_method())) + + for impl in allowed_providers: + method_name = f"from_{impl}" + setattr( + cls, + method_name, + classmethod(_per_type_injected_method(method_name, impl)), + ) - Parameters - ---------- - env : dict - Environment dictionary containing properties and methods to add. - """ - for k, v in env.items(): - if isinstance(v, property): - setattr(self.__class__, k, v) - elif callable(v): - setattr(self, k, functools.partial(v, self)) - else: - setattr(self, k, v) + return cls -class DeployerImpl(object): +class DeployedFlow(metaclass=DeployedFlowMeta): """ - Base class for deployer implementations. Each implementation should define a TYPE - class variable that matches the name of the CLI group. + DeployedFlow class represents a flow that has been deployed. - Parameters - ---------- - flow_file : str - Path to the flow file to deploy. - show_output : bool, default True - Show the 'stdout' and 'stderr' to the console by default. - profile : Optional[str], default None - Metaflow profile to use for the deployment. If not specified, the default - profile is used. - env : Optional[Dict], default None - Additional environment variables to set for the deployment. - cwd : Optional[str], default None - The directory to run the subprocess in; if not specified, the current - directory is used. 
- file_read_timeout : int, default 3600 - The timeout until which we try to read the deployer attribute file. - **kwargs : Any - Additional arguments that you would pass to `python myflow.py` before - the deployment command. + This class is not meant to be instantiated directly. Instead, it is returned from + methods of `Deployer`. """ + # This should match the TYPE value in DeployerImpl for proper stub generation TYPE: ClassVar[Optional[str]] = None - def __init__( - self, - flow_file: str, - show_output: bool = True, - profile: Optional[str] = None, - env: Optional[Dict] = None, - cwd: Optional[str] = None, - file_read_timeout: int = 3600, - **kwargs - ): - if self.TYPE is None: - raise ValueError( - "DeployerImpl doesn't have a 'TYPE' to target. Please use a sub-class of DeployerImpl." - ) - - if "metaflow.cli" in sys.modules: - importlib.reload(sys.modules["metaflow.cli"]) - from metaflow.cli import start - from metaflow.runner.click_api import MetaflowAPI - - self.flow_file = flow_file - self.show_output = show_output - self.profile = profile - self.env = env - self.cwd = cwd - self.file_read_timeout = file_read_timeout - - self.env_vars = os.environ.copy() - self.env_vars.update(self.env or {}) - if self.profile: - self.env_vars["METAFLOW_PROFILE"] = profile - - self.spm = SubprocessManager() - self.top_level_kwargs = kwargs - self.api = MetaflowAPI.from_cli(self.flow_file, start) - - def __enter__(self) -> "DeployerImpl": - return self - - def create(self, **kwargs) -> DeployedFlow: - """ - Create a deployed flow using the deployer implementation. - - Parameters - ---------- - **kwargs : Any - Additional arguments to pass to `create` corresponding to the - command line arguments of `create` - - Returns - ------- - DeployedFlow - DeployedFlow object representing the deployed flow. - - Raises - ------ - Exception - If there is an error during deployment. - """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs - ).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) - - pid = self.spm.run_command( - [sys.executable, *command], - env=self.env_vars, - cwd=self.cwd, - show_output=self.show_output, - ) - - command_obj = self.spm.get(pid) - content = handle_timeout( - tfp_runner_attribute, command_obj, self.file_read_timeout - ) - content = json.loads(content) - self.name = content.get("name") - self.flow_name = content.get("flow_name") - self.metadata = content.get("metadata") - # Additional info is used to pass additional deployer specific information. - # It is used in non-OSS deployers (extensions). - self.additional_info = content.get("additional_info", {}) - - if command_obj.process.returncode == 0: - deployed_flow = DeployedFlow(deployer=self) - self._enrich_deployed_flow(deployed_flow) - return deployed_flow - - raise Exception("Error deploying %s to %s" % (self.flow_file, self.TYPE)) - - def _enrich_deployed_flow(self, deployed_flow: DeployedFlow): - """ - Enrich the DeployedFlow object with additional properties and methods. - - Parameters - ---------- - deployed_flow : DeployedFlow - The DeployedFlow object to enrich. - """ - raise NotImplementedError - - def __exit__(self, exc_type, exc_value, traceback): - """ - Cleanup resources on exit. - """ - self.cleanup() - - def cleanup(self): - """ - Cleanup resources. 
- """ - self.spm.cleanup() + def __init__(self, deployer: "metaflow.runner.deployer_impl.DeployerImpl"): + self.deployer = deployer + self.name = self.deployer.name + self.flow_name = self.deployer.flow_name + self.metadata = self.deployer.metadata diff --git a/metaflow/runner/deployer_impl.py b/metaflow/runner/deployer_impl.py new file mode 100644 index 00000000000..07e6cf51429 --- /dev/null +++ b/metaflow/runner/deployer_impl.py @@ -0,0 +1,167 @@ +import importlib +import json +import os +import sys +import tempfile + +from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type + +from .subprocess_manager import SubprocessManager +from .utils import get_lower_level_group, handle_timeout + +if TYPE_CHECKING: + import metaflow.runner.deployer + +# NOTE: This file is separate from the deployer.py file to prevent circular imports. +# This file is needed in any of the DeployerImpl implementations +# (like argo_workflows_deployer.py) which is in turn needed to create the Deployer +# class (ie: it uses ArgoWorkflowsDeployer to create the Deployer class). + + +class DeployerImpl(object): + """ + Base class for deployer implementations. Each implementation should define a TYPE + class variable that matches the name of the CLI group. + + Parameters + ---------- + flow_file : str + Path to the flow file to deploy. + show_output : bool, default True + Show the 'stdout' and 'stderr' to the console by default. + profile : Optional[str], default None + Metaflow profile to use for the deployment. If not specified, the default + profile is used. + env : Optional[Dict], default None + Additional environment variables to set for the deployment. + cwd : Optional[str], default None + The directory to run the subprocess in; if not specified, the current + directory is used. + file_read_timeout : int, default 3600 + The timeout until which we try to read the deployer attribute file. + **kwargs : Any + Additional arguments that you would pass to `python myflow.py` before + the deployment command. + """ + + TYPE: ClassVar[Optional[str]] = None + + def __init__( + self, + flow_file: str, + show_output: bool = True, + profile: Optional[str] = None, + env: Optional[Dict] = None, + cwd: Optional[str] = None, + file_read_timeout: int = 3600, + **kwargs + ): + if self.TYPE is None: + raise ValueError( + "DeployerImpl doesn't have a 'TYPE' to target. Please use a sub-class " + "of DeployerImpl." + ) + + if "metaflow.cli" in sys.modules: + importlib.reload(sys.modules["metaflow.cli"]) + from metaflow.cli import start + from metaflow.runner.click_api import MetaflowAPI + + self.flow_file = flow_file + self.show_output = show_output + self.profile = profile + self.env = env + self.cwd = cwd + self.file_read_timeout = file_read_timeout + + self.env_vars = os.environ.copy() + self.env_vars.update(self.env or {}) + if self.profile: + self.env_vars["METAFLOW_PROFILE"] = profile + + self.spm = SubprocessManager() + self.top_level_kwargs = kwargs + self.api = MetaflowAPI.from_cli(self.flow_file, start) + + @property + def deployer_kwargs(self) -> Dict[str, Any]: + raise NotImplementedError + + @staticmethod + def deployed_flow_type() -> Type["metaflow.runner.deployer.DeployedFlow"]: + raise NotImplementedError + + def __enter__(self) -> "DeployerImpl": + return self + + def create(self, **kwargs) -> "metaflow.runner.deployer.DeployedFlow": + """ + Create a sub-class of a `DeployedFlow` depending on the deployer implementation. 
+ + Parameters + ---------- + **kwargs : Any + Additional arguments to pass to `create` corresponding to the + command line arguments of `create` + + Returns + ------- + DeployedFlow + DeployedFlow object representing the deployed flow. + + Raises + ------ + Exception + If there is an error during deployment. + """ + # Sub-classes should implement this by simply calling _create and pass the + # proper class as the DeployedFlow to return. + raise NotImplementedError + + def _create( + self, create_class: Type["metaflow.runner.deployer.DeployedFlow"], **kwargs + ) -> "metaflow.runner.deployer.DeployedFlow": + with tempfile.TemporaryDirectory() as temp_dir: + tfp_runner_attribute = tempfile.NamedTemporaryFile( + dir=temp_dir, delete=False + ) + # every subclass needs to have `self.deployer_kwargs` + command = get_lower_level_group( + self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs + ).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + + command_obj = self.spm.get(pid) + content = handle_timeout( + tfp_runner_attribute, command_obj, self.file_read_timeout + ) + content = json.loads(content) + self.name = content.get("name") + self.flow_name = content.get("flow_name") + self.metadata = content.get("metadata") + # Additional info is used to pass additional deployer specific information. + # It is used in non-OSS deployers (extensions). + self.additional_info = content.get("additional_info", {}) + + if command_obj.process.returncode == 0: + return create_class(deployer=self) + + raise RuntimeError("Error deploying %s to %s" % (self.flow_file, self.TYPE)) + + def __exit__(self, exc_type, exc_value, traceback): + """ + Cleanup resources on exit. + """ + self.cleanup() + + def cleanup(self): + """ + Cleanup resources. + """ + self.spm.cleanup() diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index badc4cd35a1..78418f49e8a 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -67,10 +67,11 @@ async def wait( Parameters ---------- - timeout : Optional[float], default None - The maximum time to wait for the run to finish. - If the timeout is reached, the run is terminated - stream : Optional[str], default None + timeout : float, optional, default None + The maximum time, in seconds, to wait for the run to finish. + If the timeout is reached, the run is terminated. If not specified, wait + forever. + stream : str, optional, default None If specified, the specified stream is printed to stdout. `stream` can be one of `stdout` or `stderr`. @@ -167,7 +168,7 @@ async def stream_log( ---------- stream : str The stream to stream logs from. Can be one of `stdout` or `stderr`. - position : Optional[int], default None + position : int, optional, default None The position in the log file to start streaming from. If None, it starts from the beginning of the log file. This allows resuming streaming from a previously known position @@ -207,13 +208,13 @@ class Runner(object): show_output : bool, default True Show the 'stdout' and 'stderr' to the console by default, Only applicable for synchronous 'run' and 'resume' functions. - profile : Optional[str], default None + profile : str, optional, default None Metaflow profile to use to run this run. 
If not specified, the default profile is used (or the one already set using `METAFLOW_PROFILE`) - env : Optional[Dict], default None + env : Dict[str, str], optional, default None Additional environment variables to set for the Run. This overrides the environment set for this process. - cwd : Optional[str], default None + cwd : str, optional, default None The directory to run the subprocess in; if not specified, the current directory is used. file_read_timeout : int, default 3600 @@ -228,7 +229,7 @@ def __init__( flow_file: str, show_output: bool = True, profile: Optional[str] = None, - env: Optional[Dict] = None, + env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, **kwargs diff --git a/metaflow/runner/nbdeploy.py b/metaflow/runner/nbdeploy.py index dea8c0f41cf..355047d2a46 100644 --- a/metaflow/runner/nbdeploy.py +++ b/metaflow/runner/nbdeploy.py @@ -37,13 +37,13 @@ class NBDeployer(object): Flow defined in the same cell show_output : bool, default True Show the 'stdout' and 'stderr' to the console by default, - profile : Optional[str], default None + profile : str, optional, default None Metaflow profile to use to deploy this run. If not specified, the default profile is used (or the one already set using `METAFLOW_PROFILE`) - env : Optional[Dict[str, str]], default None + env : Dict[str, str], optional, default None Additional environment variables to set. This overrides the environment set for this process. - base_dir : Optional[str], default None + base_dir : str, optional, default None The directory to run the subprocess in; if not specified, the current working directory is used. **kwargs : Any @@ -66,10 +66,11 @@ def __init__( from IPython import get_ipython ipython = get_ipython() - except ModuleNotFoundError: + except ModuleNotFoundError as e: raise NBDeployerInitializationError( - "'NBDeployer' requires an interactive Python environment (such as Jupyter)" - ) + "'NBDeployer' requires an interactive Python environment " + "(such as Jupyter)" + ) from e self.cell = get_current_cell(ipython) self.flow = flow @@ -116,13 +117,11 @@ def __init__( **kwargs, ) - from metaflow.plugins import DEPLOYER_IMPL_PROVIDERS - - for provider_class in DEPLOYER_IMPL_PROVIDERS: - method_name = provider_class.TYPE.replace("-", "_") - setattr( - NBDeployer, method_name, self.deployer.__make_function(provider_class) - ) + def __getattr__(self, name): + """ + Forward all attribute access to the underlying `Deployer` instance. + """ + return getattr(self.deployer, name) def cleanup(self): """ diff --git a/metaflow/runner/nbrun.py b/metaflow/runner/nbrun.py index 400a77a85b8..10fc473eff3 100644 --- a/metaflow/runner/nbrun.py +++ b/metaflow/runner/nbrun.py @@ -34,13 +34,13 @@ class NBRunner(object): show_output : bool, default True Show the 'stdout' and 'stderr' to the console by default, Only applicable for synchronous 'run' and 'resume' functions. - profile : Optional[str], default None + profile : str, optional, default None Metaflow profile to use to run this run. If not specified, the default profile is used (or the one already set using `METAFLOW_PROFILE`) - env : Optional[Dict], default None + env : Dict[str, str], optional, default None Additional environment variables to set for the Run. This overrides the environment set for this process. - base_dir : Optional[str], default None + base_dir : str, optional, default None The directory to run the subprocess in; if not specified, the current working directory is used. 
file_read_timeout : int, default 3600 diff --git a/metaflow/runner/utils.py b/metaflow/runner/utils.py index eaac78bef78..0ef202a3a5a 100644 --- a/metaflow/runner/utils.py +++ b/metaflow/runner/utils.py @@ -4,10 +4,12 @@ import asyncio from subprocess import CalledProcessError -from typing import Dict, TYPE_CHECKING +from typing import Any, Dict, TYPE_CHECKING if TYPE_CHECKING: - from .subprocess_manager import CommandManager + import tempfile + import metaflow.runner.subprocess_manager + import metaflow.runner.click_api def get_current_cell(ipython): @@ -18,7 +20,8 @@ def get_current_cell(ipython): def format_flowfile(cell): """ - Formats the given cell content to create a valid Python script that can be executed as a Metaflow flow. + Formats the given cell content to create a valid Python script that can be + executed as a Metaflow flow. """ flowspec = [ x @@ -36,7 +39,9 @@ def format_flowfile(cell): return "\n".join(lines) -def check_process_status(command_obj: "CommandManager"): +def check_process_status( + command_obj: "metaflow.runner.subprocess_manager.CommandManager", +): if isinstance(command_obj.process, asyncio.subprocess.Process): return command_obj.process.returncode is not None else: @@ -44,7 +49,9 @@ def check_process_status(command_obj: "CommandManager"): def read_from_file_when_ready( - file_path: str, command_obj: "CommandManager", timeout: float = 5 + file_path: str, + command_obj: "metaflow.runner.subprocess_manager.CommandManager", + timeout: float = 5, ): start_time = time.time() with open(file_path, "r", encoding="utf-8") as file_pointer: @@ -70,7 +77,9 @@ def read_from_file_when_ready( def handle_timeout( - tfp_runner_attribute, command_obj: "CommandManager", file_read_timeout: int + tfp_runner_attribute: "tempfile._TemporaryFileWrapper[str]", + command_obj: "metaflow.runner.subprocess_manager.CommandManager", + file_read_timeout: int, ): """ Handle the timeout for a running subprocess command that reads a file @@ -102,8 +111,8 @@ def handle_timeout( ) return content except (CalledProcessError, TimeoutError) as e: - stdout_log = open(command_obj.log_files["stdout"]).read() - stderr_log = open(command_obj.log_files["stderr"]).read() + stdout_log = open(command_obj.log_files["stdout"], encoding="utf-8").read() + stderr_log = open(command_obj.log_files["stderr"], encoding="utf-8").read() command = " ".join(command_obj.command) error_message = "Error executing: '%s':\n" % command if stdout_log.strip(): @@ -111,3 +120,41 @@ def handle_timeout( if stderr_log.strip(): error_message += "\nStderr:\n%s\n" % stderr_log raise RuntimeError(error_message) from e + + +def get_lower_level_group( + api: "metaflow.runner.click_api.MetaflowAPI", + top_level_kwargs: Dict[str, Any], + sub_command: str, + sub_command_kwargs: Dict[str, Any], +) -> "metaflow.runner.click_api.MetaflowAPI": + """ + Retrieve a lower-level group from the API based on the type and provided arguments. + + Parameters + ---------- + api : MetaflowAPI + Metaflow API instance. + top_level_kwargs : Dict[str, Any] + Top-level keyword arguments to pass to the API. + sub_command : str + Sub-command of API to get the API for + sub_command_kwargs : Dict[str, Any] + Sub-command arguments + + Returns + ------- + MetaflowAPI + The lower-level group object retrieved from the API. + + Raises + ------ + ValueError + If the `_type` is None. 
+ """ + sub_command_obj = getattr(api(**top_level_kwargs), sub_command) + + if sub_command_obj is None: + raise ValueError(f"Sub-command '{sub_command}' not found in API '{api.name}'") + + return sub_command_obj(**sub_command_kwargs) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index d5fbc0b6837..5f86210aa67 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from . import get_namespace -from .metadata import MetaDatum +from .metadata_provider import MetaDatum from .metaflow_config import MAX_ATTEMPTS, UI_URL from .exception import ( MetaflowException, diff --git a/metaflow/task.py b/metaflow/task.py index bba15c45471..6b73302652b 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -12,7 +12,7 @@ from metaflow.datastore.exceptions import DataException from .metaflow_config import MAX_ATTEMPTS -from .metadata import MetaDatum +from .metadata_provider import MetaDatum from .mflog import TASK_LOG_SOURCE from .datastore import Inputs, TaskDataStoreSet from .exception import ( diff --git a/stubs/test/test_stubs.yml b/stubs/test/test_stubs.yml index 044e0ff0bf0..7ef1ec26666 100644 --- a/stubs/test/test_stubs.yml +++ b/stubs/test/test_stubs.yml @@ -7,8 +7,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, project @@ -52,8 +51,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, pypi_base @@ -101,8 +99,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, conda_base @@ -150,8 +147,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, schedule @@ -199,8 +195,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, trigger @@ -248,8 +243,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, trigger_on_finish @@ -297,8 +291,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step @@ -362,8 +355,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, retry, catch @@ -401,8 +393,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, batch @@ -462,8 +453,7 @@ - python_version: "3.10" - 
python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, kubernetes @@ -519,8 +509,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, environment @@ -580,8 +569,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, card @@ -641,8 +629,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, catch @@ -702,8 +689,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, pypi @@ -763,8 +749,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, conda @@ -824,8 +809,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, resources @@ -885,8 +869,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, retry @@ -946,8 +929,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, secrets @@ -1007,8 +989,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, timeout @@ -1067,8 +1048,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import Flow, Run, Step, Task, DataArtifact @@ -1082,7 +1062,7 @@ Run("flow_name/run_id")[0] Step("flow_name/run_id/step_name")[0] out: | - main:3: note: Revealed type is "metaflow.Flow" + main:3: note: Revealed type is "metaflow.client.core.Flow" main:4: note: Revealed type is "metaflow.client.core.Run" main:5: note: Revealed type is "metaflow.client.core.Step" main:6: note: Revealed type is "metaflow.client.core.Task" @@ -1099,8 +1079,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} + mypy_config: python_version = {{ python_version }} main: | from metaflow import current @@ -1121,8 +1100,7 @@ - python_version: "3.10" - python_version: "3.11" - python_version: "3.12" - mypy_config: - python_version = {{ python_version }} 
+ mypy_config: python_version = {{ python_version }} main: | from metaflow import FlowSpec, step, batch, project, schedule From 0bc4a9683ba67eedd756a8dc777916020587d5f7 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 7 Nov 2024 13:28:49 -0800 Subject: [PATCH 11/22] Prepare release 2.12.29 (#2136) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index 11c1fdd73ce..0fe0ee7dc5a 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.28" +metaflow_version = "2.12.29" From a3312cc0d406591c4b59261ed72df736481982e4 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen <64256562+saikonen@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:56:37 +0200 Subject: [PATCH 12/22] Revert "better error message with dump (#2130)" (#2141) This reverts commit 65fd88891dade200f5697f17e57787638ee97a98. --- metaflow/cli.py | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 1fc6a14953f..a318b84a3ec 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -282,31 +282,21 @@ def dump(obj, input_path, private=None, max_value_size=None, include=None, file= else: ds_list = list(datastore_set) # get all tasks - tasks_processed = False for ds in ds_list: - if ds is not None: - tasks_processed = True - echo( - "Dumping output of run_id=*{run_id}* " - "step=*{step}* task_id=*{task_id}*".format( - run_id=ds.run_id, step=ds.step_name, task_id=ds.task_id - ), - fg="magenta", - ) - - if file is None: - echo_always( - ds.format(**kwargs), - highlight="green", - highlight_bold=False, - err=False, - ) - else: - output[ds.pathspec] = ds.to_dict(**kwargs) + echo( + "Dumping output of run_id=*{run_id}* " + "step=*{step}* task_id=*{task_id}*".format( + run_id=ds.run_id, step=ds.step_name, task_id=ds.task_id + ), + fg="magenta", + ) - if not tasks_processed: - echo(f"No task(s) found for pathspec {input_path}", fg="red") - return + if file is None: + echo_always( + ds.format(**kwargs), highlight="green", highlight_bold=False, err=False + ) + else: + output[ds.pathspec] = ds.to_dict(**kwargs) if file is not None: with open(file, "wb") as f: From 05f9756077fc98d10be47b434863c101829544f9 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen <64256562+saikonen@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:59:59 +0200 Subject: [PATCH 13/22] bump version to 2.12.30 (#2142) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index 0fe0ee7dc5a..14696aa8064 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.29" +metaflow_version = "2.12.30" From d8c7cc5afc74a9d4d0adb35f47c54d4d4d0bb315 Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Thu, 21 Nov 2024 13:55:56 -0800 Subject: [PATCH 14/22] [jobsets] Fix killing jobsets using deletion (#2149) --- .../plugins/kubernetes/kubernetes_jobsets.py | 71 +++++++++++-------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index cf6c6affe2b..c84d655fc42 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -4,7 +4,6 @@ import random import time from collections import namedtuple - from metaflow.exception import MetaflowException from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, 
KUBERNETES_JOBSET_VERSION from metaflow.tracing import inject_tracing_vars @@ -320,33 +319,49 @@ def _fetch_pod(self): def kill(self): plural = "jobsets" client = self._client.get() - # Get the jobset - with client.ApiClient() as api_client: - api_instance = client.CustomObjectsApi(api_client) - try: - obj = api_instance.get_namespaced_custom_object( - group=self._group, - version=self._version, - namespace=self._namespace, - plural=plural, - name=self._name, - ) - - # Suspend the jobset - obj["spec"]["suspend"] = True - - api_instance.replace_namespaced_custom_object( - group=self._group, - version=self._version, - namespace=self._namespace, - plural=plural, - name=obj["metadata"]["name"], - body=obj, - ) - except Exception as e: - raise KubernetesJobsetException( - "Exception when suspending existing jobset: %s\n" % e - ) + try: + # Killing the control pod will trigger the jobset to mark everything as failed. + # Since jobsets have a successPolicy set to `All` which ensures that everything has + # to succeed for the jobset to succeed. + from kubernetes.stream import stream + + control_pod = self._fetch_pod() + stream( + client.CoreV1Api().connect_get_namespaced_pod_exec, + name=control_pod["metadata"]["name"], + namespace=control_pod["metadata"]["namespace"], + command=[ + "/bin/sh", + "-c", + "/sbin/killall5", + ], + stderr=True, + stdin=False, + stdout=True, + tty=False, + ) + except Exception as e: + with client.ApiClient() as api_client: + # If we are unable to kill the control pod then + # Delete the jobset to kill the subsequent pods. + # There are a few reasons for deleting a jobset to kill it: + # 1. Jobset has a `suspend` attribute to suspend its execution, but this + # doesn't play nicely when jobsets are deployed with other components like kueue. + # 2. Jobset doesn't play nicely when we mutate status + # 3. Deletion is a guaranteed way of removing any pods.
+ api_instance = client.CustomObjectsApi(api_client) + try: + api_instance.delete_namespaced_custom_object( + group=self._group, + version=self._version, + namespace=self._namespace, + plural=plural, + name=self._name, + ) + except Exception as e: + raise KubernetesJobsetException( + "Exception when deleting existing jobset: %s\n" % e + ) @property def id(self): From 50298d70b4e3be4e480bec08f5e9bdb2a7eef1eb Mon Sep 17 00:00:00 2001 From: KaylaSeeley <42901681+KaylaSeeley@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:22:20 -0800 Subject: [PATCH 15/22] Deploy time triggers (#2133) * trigger_on_finish sorta works * trigger deco works for event * trigger events changes * run black * ok this one ran black * black ran for real * deleting some things i missed * add format_deploytime_value() to both decos * fixes error with types * ran black * fixes failing cases * remove json.loads and add tuple to param type * ran black * function within parameter * Delete local config file * fixing borked cases * refactor code * shouldn't be changes to flowspec.py * fixing bugs * add deploy time trigger inits to argo parameter handling (#2146) * run black * undo modifications to to_pod * reset util file * pr comment * remove print --------- Co-authored-by: kayla seeley Co-authored-by: Sakari Ikonen <64256562+saikonen@users.noreply.github.com> --- metaflow/parameters.py | 10 +- metaflow/plugins/argo/argo_workflows.py | 12 +- metaflow/plugins/events_decorator.py | 325 ++++++++++++++++++------ 3 files changed, 269 insertions(+), 78 deletions(-) diff --git a/metaflow/parameters.py b/metaflow/parameters.py index fe0dabbda3f..e5778e6cd1e 100644 --- a/metaflow/parameters.py +++ b/metaflow/parameters.py @@ -151,6 +151,7 @@ def __call__(self, deploy_time=False): return self._check_type(val, deploy_time) def _check_type(self, val, deploy_time): + # it is easy to introduce a deploy-time function that accidentally # returns a value whose type is not compatible with what is defined # in Parameter. Let's catch those mistakes early here, instead of @@ -158,7 +159,7 @@ def _check_type(self, val, deploy_time): # note: this doesn't work with long in Python2 or types defined as # click types, e.g. click.INT - TYPES = {bool: "bool", int: "int", float: "float", list: "list"} + TYPES = {bool: "bool", int: "int", float: "float", list: "list", dict: "dict"} msg = ( "The value returned by the deploy-time function for " @@ -166,7 +167,12 @@ def _check_type(self, val, deploy_time): % (self.parameter_name, self.field) ) - if self.parameter_type in TYPES: + if isinstance(self.parameter_type, list): + if not any(isinstance(val, x) for x in self.parameter_type): + msg += "Expected one of the following %s." % TYPES[self.parameter_type] + raise ParameterFieldTypeMismatch(msg) + return str(val) if self.return_str else val + elif self.parameter_type in TYPES: if type(val) != self.parameter_type: msg += "Expected a %s." 
% TYPES[self.parameter_type] raise ParameterFieldTypeMismatch(msg) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index c4e8cbd6c77..05371eeca69 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -522,7 +522,9 @@ def _process_triggers(self): params = set( [param.name.lower() for var, param in self.flow._get_parameters()] ) - for event in self.flow._flow_decorators.get("trigger")[0].triggers: + trigger_deco = self.flow._flow_decorators.get("trigger")[0] + trigger_deco.format_deploytime_value() + for event in trigger_deco.triggers: parameters = {} # TODO: Add a check to guard against names starting with numerals(?) if not re.match(r"^[A-Za-z0-9_.-]+$", event["name"]): @@ -562,9 +564,11 @@ def _process_triggers(self): # @trigger_on_finish decorator if self.flow._flow_decorators.get("trigger_on_finish"): - for event in self.flow._flow_decorators.get("trigger_on_finish")[ - 0 - ].triggers: + trigger_on_finish_deco = self.flow._flow_decorators.get( + "trigger_on_finish" + )[0] + trigger_on_finish_deco.format_deploytime_value() + for event in trigger_on_finish_deco.triggers: # Actual filters are deduced here since we don't have access to # the current object in the @trigger_on_finish decorator. triggers.append( diff --git a/metaflow/plugins/events_decorator.py b/metaflow/plugins/events_decorator.py index baa6320b0ba..c9090f547fb 100644 --- a/metaflow/plugins/events_decorator.py +++ b/metaflow/plugins/events_decorator.py @@ -1,9 +1,11 @@ import re +import json from metaflow import current from metaflow.decorators import FlowDecorator from metaflow.exception import MetaflowException from metaflow.util import is_stringish +from metaflow.parameters import DeployTimeField, deploy_time_eval # TODO: Support dynamic parameter mapping through a context object that exposes # flow name and user name similar to parameter context @@ -68,6 +70,75 @@ class TriggerDecorator(FlowDecorator): "options": {}, } + def process_event_name(self, event): + if is_stringish(event): + return {"name": str(event)} + elif isinstance(event, dict): + if "name" not in event: + raise MetaflowException( + "The *event* attribute for *@trigger* is missing the *name* key." + ) + if callable(event["name"]) and not isinstance( + event["name"], DeployTimeField + ): + event["name"] = DeployTimeField( + "event_name", str, None, event["name"], False + ) + event["parameters"] = self.process_parameters(event.get("parameters", {})) + return event + elif callable(event) and not isinstance(event, DeployTimeField): + return DeployTimeField("event", [str, dict], None, event, False) + else: + raise MetaflowException( + "Incorrect format for *event* attribute in *@trigger* decorator. 
" + "Supported formats are string and dictionary - \n" + "@trigger(event='foo') or @trigger(event={'name': 'foo', " + "'parameters': {'alpha': 'beta'}})" + ) + + def process_parameters(self, parameters): + new_param_values = {} + if isinstance(parameters, (list, tuple)): + for mapping in parameters: + if is_stringish(mapping): + new_param_values[mapping] = mapping + elif callable(mapping) and not isinstance(mapping, DeployTimeField): + mapping = DeployTimeField( + "parameter_val", str, None, mapping, False + ) + new_param_values[mapping] = mapping + elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: + if callable(mapping[0]) and not isinstance( + mapping[0], DeployTimeField + ): + mapping[0] = DeployTimeField( + "parameter_val", str, None, mapping[0], False + ) + if callable(mapping[1]) and not isinstance( + mapping[1], DeployTimeField + ): + mapping[1] = DeployTimeField( + "parameter_val", str, None, mapping[1], False + ) + new_param_values[mapping[0]] = mapping[1] + else: + raise MetaflowException( + "The *parameters* attribute for event is invalid. " + "It should be a list/tuple of strings and lists/tuples of size 2" + ) + elif callable(parameters) and not isinstance(parameters, DeployTimeField): + return DeployTimeField( + "parameters", [list, dict, tuple], None, parameters, False + ) + elif isinstance(parameters, dict): + for key, value in parameters.items(): + if callable(key) and not isinstance(key, DeployTimeField): + key = DeployTimeField("flow_parameter", str, None, key, False) + if callable(value) and not isinstance(value, DeployTimeField): + value = DeployTimeField("signal_parameter", str, None, value, False) + new_param_values[key] = value + return new_param_values + def flow_init( self, flow_name, @@ -86,41 +157,9 @@ def flow_init( "attributes in *@trigger* decorator." ) elif self.attributes["event"]: - # event attribute supports the following formats - - # 1. event='table.prod_db.members' - # 2. event={'name': 'table.prod_db.members', - # 'parameters': {'alpha': 'member_weight'}} - if is_stringish(self.attributes["event"]): - self.triggers.append({"name": str(self.attributes["event"])}) - elif isinstance(self.attributes["event"], dict): - if "name" not in self.attributes["event"]: - raise MetaflowException( - "The *event* attribute for *@trigger* is missing the " - "*name* key." - ) - param_value = self.attributes["event"].get("parameters", {}) - if isinstance(param_value, (list, tuple)): - new_param_value = {} - for mapping in param_value: - if is_stringish(mapping): - new_param_value[mapping] = mapping - elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: - new_param_value[mapping[0]] = mapping[1] - else: - raise MetaflowException( - "The *parameters* attribute for event '%s' is invalid. " - "It should be a list/tuple of strings and lists/tuples " - "of size 2" % self.attributes["event"]["name"] - ) - self.attributes["event"]["parameters"] = new_param_value - self.triggers.append(self.attributes["event"]) - else: - raise MetaflowException( - "Incorrect format for *event* attribute in *@trigger* decorator. " - "Supported formats are string and dictionary - \n" - "@trigger(event='foo') or @trigger(event={'name': 'foo', " - "'parameters': {'alpha': 'beta'}})" - ) + event = self.attributes["event"] + processed_event = self.process_event_name(event) + self.triggers.append(processed_event) elif self.attributes["events"]: # events attribute supports the following formats - # 1. 
events=[{'name': 'table.prod_db.members', @@ -128,43 +167,17 @@ def flow_init( # {'name': 'table.prod_db.metadata', # 'parameters': {'beta': 'grade'}}] if isinstance(self.attributes["events"], list): + # process every event in events for event in self.attributes["events"]: - if is_stringish(event): - self.triggers.append({"name": str(event)}) - elif isinstance(event, dict): - if "name" not in event: - raise MetaflowException( - "One or more events in *events* attribute for " - "*@trigger* are missing the *name* key." - ) - param_value = event.get("parameters", {}) - if isinstance(param_value, (list, tuple)): - new_param_value = {} - for mapping in param_value: - if is_stringish(mapping): - new_param_value[mapping] = mapping - elif ( - isinstance(mapping, (list, tuple)) - and len(mapping) == 2 - ): - new_param_value[mapping[0]] = mapping[1] - else: - raise MetaflowException( - "The *parameters* attribute for event '%s' is " - "invalid. It should be a list/tuple of strings " - "and lists/tuples of size 2" % event["name"] - ) - event["parameters"] = new_param_value - self.triggers.append(event) - else: - raise MetaflowException( - "One or more events in *events* attribute in *@trigger* " - "decorator have an incorrect format. Supported format " - "is dictionary - \n" - "@trigger(events=[{'name': 'foo', 'parameters': {'alpha': " - "'beta'}}, {'name': 'bar', 'parameters': " - "{'gamma': 'kappa'}}])" - ) + processed_event = self.process_event_name(event) + self.triggers.append("processed event", processed_event) + elif callable(self.attributes["events"]) and not isinstance( + self.attributes["events"], DeployTimeField + ): + trig = DeployTimeField( + "events", list, None, self.attributes["events"], False + ) + self.triggers.append(trig) else: raise MetaflowException( "Incorrect format for *events* attribute in *@trigger* decorator. " @@ -178,7 +191,12 @@ def flow_init( raise MetaflowException("No event(s) specified in *@trigger* decorator.") # same event shouldn't occur more than once - names = [x["name"] for x in self.triggers] + names = [ + x["name"] + for x in self.triggers + if not isinstance(x, DeployTimeField) + and not isinstance(x["name"], DeployTimeField) + ] if len(names) != len(set(names)): raise MetaflowException( "Duplicate event names defined in *@trigger* decorator." @@ -188,6 +206,104 @@ def flow_init( # TODO: Handle scenario for local testing using --trigger. 
+ def format_deploytime_value(self): + new_triggers = [] + for trigger in self.triggers: + # Case where trigger is a function that returns a list of events + # Need to do this bc we need to iterate over list later + if isinstance(trigger, DeployTimeField): + evaluated_trigger = deploy_time_eval(trigger) + if isinstance(evaluated_trigger, dict): + trigger = evaluated_trigger + elif isinstance(evaluated_trigger, str): + trigger = {"name": evaluated_trigger} + if isinstance(evaluated_trigger, list): + for trig in evaluated_trigger: + if is_stringish(trig): + new_triggers.append({"name": trig}) + else: # dict or another deploytimefield + new_triggers.append(trig) + else: + new_triggers.append(trigger) + else: + new_triggers.append(trigger) + + self.triggers = new_triggers + for trigger in self.triggers: + old_trigger = trigger + trigger_params = trigger.get("parameters", {}) + # Case where param is a function (can return list or dict) + if isinstance(trigger_params, DeployTimeField): + trigger_params = deploy_time_eval(trigger_params) + # If params is a list of strings, convert to dict with same key and value + if isinstance(trigger_params, (list, tuple)): + new_trigger_params = {} + for mapping in trigger_params: + if is_stringish(mapping) or callable(mapping): + new_trigger_params[mapping] = mapping + elif callable(mapping) and not isinstance(mapping, DeployTimeField): + mapping = DeployTimeField( + "parameter_val", str, None, mapping, False + ) + new_trigger_params[mapping] = mapping + elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: + if callable(mapping[0]) and not isinstance( + mapping[0], DeployTimeField + ): + mapping[0] = DeployTimeField( + "parameter_val", + str, + None, + mapping[1], + False, + ) + if callable(mapping[1]) and not isinstance( + mapping[1], DeployTimeField + ): + mapping[1] = DeployTimeField( + "parameter_val", + str, + None, + mapping[1], + False, + ) + + new_trigger_params[mapping[0]] = mapping[1] + else: + raise MetaflowException( + "The *parameters* attribute for event '%s' is invalid. 
" + "It should be a list/tuple of strings and lists/tuples " + "of size 2" % self.attributes["event"]["name"] + ) + trigger_params = new_trigger_params + trigger["parameters"] = trigger_params + + trigger_name = trigger.get("name") + # Case where just the name is a function (always a str) + if isinstance(trigger_name, DeployTimeField): + trigger_name = deploy_time_eval(trigger_name) + trigger["name"] = trigger_name + + # Third layer + # {name:, parameters:[func, ..., ...]} + # {name:, parameters:{func : func2}} + for trigger in self.triggers: + old_trigger = trigger + trigger_params = trigger.get("parameters", {}) + new_trigger_params = {} + for key, value in trigger_params.items(): + if isinstance(value, DeployTimeField) and key is value: + evaluated_param = deploy_time_eval(value) + new_trigger_params[evaluated_param] = evaluated_param + elif isinstance(value, DeployTimeField): + new_trigger_params[key] = deploy_time_eval(value) + elif isinstance(key, DeployTimeField): + new_trigger_params[deploy_time_eval(key)] = value + else: + new_trigger_params[key] = value + trigger["parameters"] = new_trigger_params + self.triggers[self.triggers.index(old_trigger)] = trigger + class TriggerOnFinishDecorator(FlowDecorator): """ @@ -312,6 +428,13 @@ def flow_init( "The *project_branch* attribute of the *flow* is not a string" ) self.triggers.append(result) + elif callable(self.attributes["flow"]) and not isinstance( + self.attributes["flow"], DeployTimeField + ): + trig = DeployTimeField( + "fq_name", [str, dict], None, self.attributes["flow"], False + ) + self.triggers.append(trig) else: raise MetaflowException( "Incorrect type for *flow* attribute in *@trigger_on_finish* " @@ -369,6 +492,13 @@ def flow_init( "Supported type is string or Dict[str, str]- \n" "@trigger_on_finish(flows=['FooFlow', 'BarFlow']" ) + elif callable(self.attributes["flows"]) and not isinstance( + self.attributes["flows"], DeployTimeField + ): + trig = DeployTimeField( + "flows", list, None, self.attributes["flows"], False + ) + self.triggers.append(trig) else: raise MetaflowException( "Incorrect type for *flows* attribute in *@trigger_on_finish* " @@ -383,6 +513,8 @@ def flow_init( # Make triggers @project aware for trigger in self.triggers: + if isinstance(trigger, DeployTimeField): + continue if trigger["fq_name"].count(".") == 0: # fully qualified name is just the flow name trigger["flow"] = trigger["fq_name"] @@ -427,5 +559,54 @@ def flow_init( run_objs.append(run_obj) current._update_env({"trigger": Trigger.from_runs(run_objs)}) + def _parse_fq_name(self, trigger): + if isinstance(trigger, DeployTimeField): + trigger["fq_name"] = deploy_time_eval(trigger["fq_name"]) + if trigger["fq_name"].count(".") == 0: + # fully qualified name is just the flow name + trigger["flow"] = trigger["fq_name"] + elif trigger["fq_name"].count(".") >= 2: + # fully qualified name is of the format - project.branch.flow_name + trigger["project"], tail = trigger["fq_name"].split(".", maxsplit=1) + trigger["branch"], trigger["flow"] = tail.rsplit(".", maxsplit=1) + else: + raise MetaflowException( + "Incorrect format for *flow* in *@trigger_on_finish* " + "decorator. Specify either just the *flow_name* or a fully " + "qualified name like *project_name.branch_name.flow_name*." + ) + if not re.match(r"^[A-Za-z0-9_]+$", trigger["flow"]): + raise MetaflowException( + "Invalid flow name *%s* in *@trigger_on_finish* " + "decorator. Only alphanumeric characters and " + "underscores(_) are allowed." 
% trigger["flow"] + ) + return trigger + + def format_deploytime_value(self): + for trigger in self.triggers: + # Case were trigger is a function that returns a list + # Need to do this bc we need to iterate over list and process + if isinstance(trigger, DeployTimeField): + deploy_value = deploy_time_eval(trigger) + if isinstance(deploy_value, list): + self.triggers = deploy_value + else: + break + for trigger in self.triggers: + # Entire trigger is a function (returns either string or dict) + old_trig = trigger + if isinstance(trigger, DeployTimeField): + trigger = deploy_time_eval(trigger) + if isinstance(trigger, dict): + trigger["fq_name"] = trigger.get("name") + trigger["project"] = trigger.get("project") + trigger["branch"] = trigger.get("project_branch") + # We also added this bc it won't be formatted yet + if isinstance(trigger, str): + trigger = {"fq_name": trigger} + trigger = self._parse_fq_name(trigger) + self.triggers[self.triggers.index(old_trig)] = trigger + def get_top_level_options(self): return list(self._option_values.items()) From a3d087effed003211402ee09aebf7cc25de47eb3 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Fri, 22 Nov 2024 11:20:02 -0800 Subject: [PATCH 16/22] Bump version to 2.12.31 --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index 14696aa8064..9a35e9c0e01 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.30" +metaflow_version = "2.12.31" From 15c86b9db509edd14600b9a017865bc0348aba78 Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Mon, 25 Nov 2024 18:23:57 -0800 Subject: [PATCH 17/22] [runtime] run-hbs for local runtime's `run` command (#2150) --- metaflow/cli.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index a318b84a3ec..498ea1b74b2 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -802,20 +802,21 @@ def run( "msg": "Starting run", }, ) - runtime.print_workflow_info() - runtime.persist_constants() - - if runner_attribute_file: - with open(runner_attribute_file, "w", encoding="utf-8") as f: - json.dump( - { - "run_id": runtime.run_id, - "flow_name": obj.flow.name, - "metadata": obj.metadata.metadata_str(), - }, - f, - ) - runtime.execute() + with runtime.run_heartbeat(): + runtime.print_workflow_info() + runtime.persist_constants() + + if runner_attribute_file: + with open(runner_attribute_file, "w", encoding="utf-8") as f: + json.dump( + { + "run_id": runtime.run_id, + "flow_name": obj.flow.name, + "metadata": obj.metadata.metadata_str(), + }, + f, + ) + runtime.execute() def write_file(file_path, content): From 6355a1b1d24e438af87c63b1b895b640de315d76 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen <64256562+saikonen@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:57:40 +0200 Subject: [PATCH 18/22] bump version to 2.12.32 (#2151) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index 9a35e9c0e01..ebded2698ef 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.31" +metaflow_version = "2.12.32" From 7b594b40f0632dd8219ba890280d155234effad2 Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Tue, 26 Nov 2024 12:33:48 -0800 Subject: [PATCH 19/22] [parallel] exclude `_parallel_ubf_iter` from `merge_artifacts` (#2152) --- metaflow/flowspec.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/metaflow/flowspec.py b/metaflow/flowspec.py index 0c7ffd1f128..7a97b9714fa 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -38,6 +38,7 @@ "_unbounded_foreach", "_control_mapper_tasks", "_control_task_is_mapper_zero", + "_parallel_ubf_iter", ] ) From f3a1857880c4325788b72c3f18466a1c10d98aac Mon Sep 17 00:00:00 2001 From: Savin Date: Wed, 27 Nov 2024 10:47:33 -0800 Subject: [PATCH 20/22] Revert "Deploy time triggers (#2133)" (#2153) This reverts commit 50298d70b4e3be4e480bec08f5e9bdb2a7eef1eb. --- metaflow/parameters.py | 10 +- metaflow/plugins/argo/argo_workflows.py | 12 +- metaflow/plugins/events_decorator.py | 325 ++++++------------------ 3 files changed, 78 insertions(+), 269 deletions(-) diff --git a/metaflow/parameters.py b/metaflow/parameters.py index e5778e6cd1e..fe0dabbda3f 100644 --- a/metaflow/parameters.py +++ b/metaflow/parameters.py @@ -151,7 +151,6 @@ def __call__(self, deploy_time=False): return self._check_type(val, deploy_time) def _check_type(self, val, deploy_time): - # it is easy to introduce a deploy-time function that accidentally # returns a value whose type is not compatible with what is defined # in Parameter. Let's catch those mistakes early here, instead of @@ -159,7 +158,7 @@ def _check_type(self, val, deploy_time): # note: this doesn't work with long in Python2 or types defined as # click types, e.g. click.INT - TYPES = {bool: "bool", int: "int", float: "float", list: "list", dict: "dict"} + TYPES = {bool: "bool", int: "int", float: "float", list: "list"} msg = ( "The value returned by the deploy-time function for " @@ -167,12 +166,7 @@ def _check_type(self, val, deploy_time): % (self.parameter_name, self.field) ) - if isinstance(self.parameter_type, list): - if not any(isinstance(val, x) for x in self.parameter_type): - msg += "Expected one of the following %s." % TYPES[self.parameter_type] - raise ParameterFieldTypeMismatch(msg) - return str(val) if self.return_str else val - elif self.parameter_type in TYPES: + if self.parameter_type in TYPES: if type(val) != self.parameter_type: msg += "Expected a %s." % TYPES[self.parameter_type] raise ParameterFieldTypeMismatch(msg) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 05371eeca69..c4e8cbd6c77 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -522,9 +522,7 @@ def _process_triggers(self): params = set( [param.name.lower() for var, param in self.flow._get_parameters()] ) - trigger_deco = self.flow._flow_decorators.get("trigger")[0] - trigger_deco.format_deploytime_value() - for event in trigger_deco.triggers: + for event in self.flow._flow_decorators.get("trigger")[0].triggers: parameters = {} # TODO: Add a check to guard against names starting with numerals(?) if not re.match(r"^[A-Za-z0-9_.-]+$", event["name"]): @@ -564,11 +562,9 @@ def _process_triggers(self): # @trigger_on_finish decorator if self.flow._flow_decorators.get("trigger_on_finish"): - trigger_on_finish_deco = self.flow._flow_decorators.get( - "trigger_on_finish" - )[0] - trigger_on_finish_deco.format_deploytime_value() - for event in trigger_on_finish_deco.triggers: + for event in self.flow._flow_decorators.get("trigger_on_finish")[ + 0 + ].triggers: # Actual filters are deduced here since we don't have access to # the current object in the @trigger_on_finish decorator. 
triggers.append( diff --git a/metaflow/plugins/events_decorator.py b/metaflow/plugins/events_decorator.py index c9090f547fb..baa6320b0ba 100644 --- a/metaflow/plugins/events_decorator.py +++ b/metaflow/plugins/events_decorator.py @@ -1,11 +1,9 @@ import re -import json from metaflow import current from metaflow.decorators import FlowDecorator from metaflow.exception import MetaflowException from metaflow.util import is_stringish -from metaflow.parameters import DeployTimeField, deploy_time_eval # TODO: Support dynamic parameter mapping through a context object that exposes # flow name and user name similar to parameter context @@ -70,75 +68,6 @@ class TriggerDecorator(FlowDecorator): "options": {}, } - def process_event_name(self, event): - if is_stringish(event): - return {"name": str(event)} - elif isinstance(event, dict): - if "name" not in event: - raise MetaflowException( - "The *event* attribute for *@trigger* is missing the *name* key." - ) - if callable(event["name"]) and not isinstance( - event["name"], DeployTimeField - ): - event["name"] = DeployTimeField( - "event_name", str, None, event["name"], False - ) - event["parameters"] = self.process_parameters(event.get("parameters", {})) - return event - elif callable(event) and not isinstance(event, DeployTimeField): - return DeployTimeField("event", [str, dict], None, event, False) - else: - raise MetaflowException( - "Incorrect format for *event* attribute in *@trigger* decorator. " - "Supported formats are string and dictionary - \n" - "@trigger(event='foo') or @trigger(event={'name': 'foo', " - "'parameters': {'alpha': 'beta'}})" - ) - - def process_parameters(self, parameters): - new_param_values = {} - if isinstance(parameters, (list, tuple)): - for mapping in parameters: - if is_stringish(mapping): - new_param_values[mapping] = mapping - elif callable(mapping) and not isinstance(mapping, DeployTimeField): - mapping = DeployTimeField( - "parameter_val", str, None, mapping, False - ) - new_param_values[mapping] = mapping - elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: - if callable(mapping[0]) and not isinstance( - mapping[0], DeployTimeField - ): - mapping[0] = DeployTimeField( - "parameter_val", str, None, mapping[0], False - ) - if callable(mapping[1]) and not isinstance( - mapping[1], DeployTimeField - ): - mapping[1] = DeployTimeField( - "parameter_val", str, None, mapping[1], False - ) - new_param_values[mapping[0]] = mapping[1] - else: - raise MetaflowException( - "The *parameters* attribute for event is invalid. " - "It should be a list/tuple of strings and lists/tuples of size 2" - ) - elif callable(parameters) and not isinstance(parameters, DeployTimeField): - return DeployTimeField( - "parameters", [list, dict, tuple], None, parameters, False - ) - elif isinstance(parameters, dict): - for key, value in parameters.items(): - if callable(key) and not isinstance(key, DeployTimeField): - key = DeployTimeField("flow_parameter", str, None, key, False) - if callable(value) and not isinstance(value, DeployTimeField): - value = DeployTimeField("signal_parameter", str, None, value, False) - new_param_values[key] = value - return new_param_values - def flow_init( self, flow_name, @@ -157,9 +86,41 @@ def flow_init( "attributes in *@trigger* decorator." ) elif self.attributes["event"]: - event = self.attributes["event"] - processed_event = self.process_event_name(event) - self.triggers.append(processed_event) + # event attribute supports the following formats - + # 1. event='table.prod_db.members' + # 2. 
event={'name': 'table.prod_db.members', + # 'parameters': {'alpha': 'member_weight'}} + if is_stringish(self.attributes["event"]): + self.triggers.append({"name": str(self.attributes["event"])}) + elif isinstance(self.attributes["event"], dict): + if "name" not in self.attributes["event"]: + raise MetaflowException( + "The *event* attribute for *@trigger* is missing the " + "*name* key." + ) + param_value = self.attributes["event"].get("parameters", {}) + if isinstance(param_value, (list, tuple)): + new_param_value = {} + for mapping in param_value: + if is_stringish(mapping): + new_param_value[mapping] = mapping + elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: + new_param_value[mapping[0]] = mapping[1] + else: + raise MetaflowException( + "The *parameters* attribute for event '%s' is invalid. " + "It should be a list/tuple of strings and lists/tuples " + "of size 2" % self.attributes["event"]["name"] + ) + self.attributes["event"]["parameters"] = new_param_value + self.triggers.append(self.attributes["event"]) + else: + raise MetaflowException( + "Incorrect format for *event* attribute in *@trigger* decorator. " + "Supported formats are string and dictionary - \n" + "@trigger(event='foo') or @trigger(event={'name': 'foo', " + "'parameters': {'alpha': 'beta'}})" + ) elif self.attributes["events"]: # events attribute supports the following formats - # 1. events=[{'name': 'table.prod_db.members', @@ -167,17 +128,43 @@ def flow_init( # {'name': 'table.prod_db.metadata', # 'parameters': {'beta': 'grade'}}] if isinstance(self.attributes["events"], list): - # process every event in events for event in self.attributes["events"]: - processed_event = self.process_event_name(event) - self.triggers.append("processed event", processed_event) - elif callable(self.attributes["events"]) and not isinstance( - self.attributes["events"], DeployTimeField - ): - trig = DeployTimeField( - "events", list, None, self.attributes["events"], False - ) - self.triggers.append(trig) + if is_stringish(event): + self.triggers.append({"name": str(event)}) + elif isinstance(event, dict): + if "name" not in event: + raise MetaflowException( + "One or more events in *events* attribute for " + "*@trigger* are missing the *name* key." + ) + param_value = event.get("parameters", {}) + if isinstance(param_value, (list, tuple)): + new_param_value = {} + for mapping in param_value: + if is_stringish(mapping): + new_param_value[mapping] = mapping + elif ( + isinstance(mapping, (list, tuple)) + and len(mapping) == 2 + ): + new_param_value[mapping[0]] = mapping[1] + else: + raise MetaflowException( + "The *parameters* attribute for event '%s' is " + "invalid. It should be a list/tuple of strings " + "and lists/tuples of size 2" % event["name"] + ) + event["parameters"] = new_param_value + self.triggers.append(event) + else: + raise MetaflowException( + "One or more events in *events* attribute in *@trigger* " + "decorator have an incorrect format. Supported format " + "is dictionary - \n" + "@trigger(events=[{'name': 'foo', 'parameters': {'alpha': " + "'beta'}}, {'name': 'bar', 'parameters': " + "{'gamma': 'kappa'}}])" + ) else: raise MetaflowException( "Incorrect format for *events* attribute in *@trigger* decorator. 
" @@ -191,12 +178,7 @@ def flow_init( raise MetaflowException("No event(s) specified in *@trigger* decorator.") # same event shouldn't occur more than once - names = [ - x["name"] - for x in self.triggers - if not isinstance(x, DeployTimeField) - and not isinstance(x["name"], DeployTimeField) - ] + names = [x["name"] for x in self.triggers] if len(names) != len(set(names)): raise MetaflowException( "Duplicate event names defined in *@trigger* decorator." @@ -206,104 +188,6 @@ def flow_init( # TODO: Handle scenario for local testing using --trigger. - def format_deploytime_value(self): - new_triggers = [] - for trigger in self.triggers: - # Case where trigger is a function that returns a list of events - # Need to do this bc we need to iterate over list later - if isinstance(trigger, DeployTimeField): - evaluated_trigger = deploy_time_eval(trigger) - if isinstance(evaluated_trigger, dict): - trigger = evaluated_trigger - elif isinstance(evaluated_trigger, str): - trigger = {"name": evaluated_trigger} - if isinstance(evaluated_trigger, list): - for trig in evaluated_trigger: - if is_stringish(trig): - new_triggers.append({"name": trig}) - else: # dict or another deploytimefield - new_triggers.append(trig) - else: - new_triggers.append(trigger) - else: - new_triggers.append(trigger) - - self.triggers = new_triggers - for trigger in self.triggers: - old_trigger = trigger - trigger_params = trigger.get("parameters", {}) - # Case where param is a function (can return list or dict) - if isinstance(trigger_params, DeployTimeField): - trigger_params = deploy_time_eval(trigger_params) - # If params is a list of strings, convert to dict with same key and value - if isinstance(trigger_params, (list, tuple)): - new_trigger_params = {} - for mapping in trigger_params: - if is_stringish(mapping) or callable(mapping): - new_trigger_params[mapping] = mapping - elif callable(mapping) and not isinstance(mapping, DeployTimeField): - mapping = DeployTimeField( - "parameter_val", str, None, mapping, False - ) - new_trigger_params[mapping] = mapping - elif isinstance(mapping, (list, tuple)) and len(mapping) == 2: - if callable(mapping[0]) and not isinstance( - mapping[0], DeployTimeField - ): - mapping[0] = DeployTimeField( - "parameter_val", - str, - None, - mapping[1], - False, - ) - if callable(mapping[1]) and not isinstance( - mapping[1], DeployTimeField - ): - mapping[1] = DeployTimeField( - "parameter_val", - str, - None, - mapping[1], - False, - ) - - new_trigger_params[mapping[0]] = mapping[1] - else: - raise MetaflowException( - "The *parameters* attribute for event '%s' is invalid. 
" - "It should be a list/tuple of strings and lists/tuples " - "of size 2" % self.attributes["event"]["name"] - ) - trigger_params = new_trigger_params - trigger["parameters"] = trigger_params - - trigger_name = trigger.get("name") - # Case where just the name is a function (always a str) - if isinstance(trigger_name, DeployTimeField): - trigger_name = deploy_time_eval(trigger_name) - trigger["name"] = trigger_name - - # Third layer - # {name:, parameters:[func, ..., ...]} - # {name:, parameters:{func : func2}} - for trigger in self.triggers: - old_trigger = trigger - trigger_params = trigger.get("parameters", {}) - new_trigger_params = {} - for key, value in trigger_params.items(): - if isinstance(value, DeployTimeField) and key is value: - evaluated_param = deploy_time_eval(value) - new_trigger_params[evaluated_param] = evaluated_param - elif isinstance(value, DeployTimeField): - new_trigger_params[key] = deploy_time_eval(value) - elif isinstance(key, DeployTimeField): - new_trigger_params[deploy_time_eval(key)] = value - else: - new_trigger_params[key] = value - trigger["parameters"] = new_trigger_params - self.triggers[self.triggers.index(old_trigger)] = trigger - class TriggerOnFinishDecorator(FlowDecorator): """ @@ -428,13 +312,6 @@ def flow_init( "The *project_branch* attribute of the *flow* is not a string" ) self.triggers.append(result) - elif callable(self.attributes["flow"]) and not isinstance( - self.attributes["flow"], DeployTimeField - ): - trig = DeployTimeField( - "fq_name", [str, dict], None, self.attributes["flow"], False - ) - self.triggers.append(trig) else: raise MetaflowException( "Incorrect type for *flow* attribute in *@trigger_on_finish* " @@ -492,13 +369,6 @@ def flow_init( "Supported type is string or Dict[str, str]- \n" "@trigger_on_finish(flows=['FooFlow', 'BarFlow']" ) - elif callable(self.attributes["flows"]) and not isinstance( - self.attributes["flows"], DeployTimeField - ): - trig = DeployTimeField( - "flows", list, None, self.attributes["flows"], False - ) - self.triggers.append(trig) else: raise MetaflowException( "Incorrect type for *flows* attribute in *@trigger_on_finish* " @@ -513,8 +383,6 @@ def flow_init( # Make triggers @project aware for trigger in self.triggers: - if isinstance(trigger, DeployTimeField): - continue if trigger["fq_name"].count(".") == 0: # fully qualified name is just the flow name trigger["flow"] = trigger["fq_name"] @@ -559,54 +427,5 @@ def flow_init( run_objs.append(run_obj) current._update_env({"trigger": Trigger.from_runs(run_objs)}) - def _parse_fq_name(self, trigger): - if isinstance(trigger, DeployTimeField): - trigger["fq_name"] = deploy_time_eval(trigger["fq_name"]) - if trigger["fq_name"].count(".") == 0: - # fully qualified name is just the flow name - trigger["flow"] = trigger["fq_name"] - elif trigger["fq_name"].count(".") >= 2: - # fully qualified name is of the format - project.branch.flow_name - trigger["project"], tail = trigger["fq_name"].split(".", maxsplit=1) - trigger["branch"], trigger["flow"] = tail.rsplit(".", maxsplit=1) - else: - raise MetaflowException( - "Incorrect format for *flow* in *@trigger_on_finish* " - "decorator. Specify either just the *flow_name* or a fully " - "qualified name like *project_name.branch_name.flow_name*." - ) - if not re.match(r"^[A-Za-z0-9_]+$", trigger["flow"]): - raise MetaflowException( - "Invalid flow name *%s* in *@trigger_on_finish* " - "decorator. Only alphanumeric characters and " - "underscores(_) are allowed." 
% trigger["flow"] - ) - return trigger - - def format_deploytime_value(self): - for trigger in self.triggers: - # Case were trigger is a function that returns a list - # Need to do this bc we need to iterate over list and process - if isinstance(trigger, DeployTimeField): - deploy_value = deploy_time_eval(trigger) - if isinstance(deploy_value, list): - self.triggers = deploy_value - else: - break - for trigger in self.triggers: - # Entire trigger is a function (returns either string or dict) - old_trig = trigger - if isinstance(trigger, DeployTimeField): - trigger = deploy_time_eval(trigger) - if isinstance(trigger, dict): - trigger["fq_name"] = trigger.get("name") - trigger["project"] = trigger.get("project") - trigger["branch"] = trigger.get("project_branch") - # We also added this bc it won't be formatted yet - if isinstance(trigger, str): - trigger = {"fq_name": trigger} - trigger = self._parse_fq_name(trigger) - self.triggers[self.triggers.index(old_trig)] = trigger - def get_top_level_options(self): return list(self._option_values.items()) From 8ccbbbe2b0a0c46ac7c081d235c6a946458094c9 Mon Sep 17 00:00:00 2001 From: Savin Date: Wed, 27 Nov 2024 14:55:03 -0800 Subject: [PATCH 21/22] Update version.py (#2154) --- metaflow/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/version.py b/metaflow/version.py index ebded2698ef..f0ff7113425 100644 --- a/metaflow/version.py +++ b/metaflow/version.py @@ -1 +1 @@ -metaflow_version = "2.12.32" +metaflow_version = "2.12.33" From 286f9ac09dc1ed5a98d4bad4c5e671424d9fbd29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20=C4=8Cack=C3=BD?= Date: Mon, 2 Dec 2024 18:28:19 +0100 Subject: [PATCH 22/22] Async runner improvements (#2056) * Add an async sigint handler * Improve async runner interface * Fix error handling in *handle_timeout * Fix async_read_from_file_when_ready * Simplify async_kill_processes_and_descendants * Group kill_process_and_descendants into single command * Call sigint manager only for live processes * Guard against empty pid list in sigint handler * Cleanup comments * Reimplement CommandManager.kill in terms of kill_processes_and_descendants * Use fifo for attribute file * Fix `temporary_fifo` type annotation * Fix `temporary_fifo` type annotation on py37 and py38 * Fix rewriting the current run metadata * Wrap execs in `async_kill_processes_and_descendants` in try/except --- .../argo/argo_workflows_deployer_objects.py | 12 +- .../step_functions_deployer_objects.py | 12 +- metaflow/runner/deployer_impl.py | 12 +- metaflow/runner/metaflow_runner.py | 62 ++--- metaflow/runner/subprocess_manager.py | 67 ++++- metaflow/runner/utils.py | 228 ++++++++++++++---- 6 files changed, 288 insertions(+), 105 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows_deployer_objects.py b/metaflow/plugins/argo/argo_workflows_deployer_objects.py index 6538b70310b..d2ec33bfff3 100644 --- a/metaflow/plugins/argo/argo_workflows_deployer_objects.py +++ b/metaflow/plugins/argo/argo_workflows_deployer_objects.py @@ -10,7 +10,7 @@ from metaflow.plugins.argo.argo_workflows import ArgoWorkflows from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun -from metaflow.runner.utils import get_lower_level_group, handle_timeout +from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo def generate_fake_flow_file_contents( @@ -341,18 +341,14 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun: Exception If there is an error during the trigger process. 
""" - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) - + with temporary_fifo() as (attribute_file_path, attribute_file_fd): # every subclass needs to have `self.deployer_kwargs` command = get_lower_level_group( self.deployer.api, self.deployer.top_level_kwargs, self.deployer.TYPE, self.deployer.deployer_kwargs, - ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + ).trigger(deployer_attribute_file=attribute_file_path, **kwargs) pid = self.deployer.spm.run_command( [sys.executable, *command], @@ -363,7 +359,7 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun: command_obj = self.deployer.spm.get(pid) content = handle_timeout( - tfp_runner_attribute, command_obj, self.deployer.file_read_timeout + attribute_file_fd, command_obj, self.deployer.file_read_timeout ) if command_obj.process.returncode == 0: diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py index 9b3528af01b..394d8739327 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +++ b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py @@ -6,7 +6,7 @@ from metaflow.plugins.aws.step_functions.step_functions import StepFunctions from metaflow.runner.deployer import DeployedFlow, TriggeredRun -from metaflow.runner.utils import get_lower_level_group, handle_timeout +from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo class StepFunctionsTriggeredRun(TriggeredRun): @@ -196,18 +196,14 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun: Exception If there is an error during the trigger process. 
""" - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) - + with temporary_fifo() as (attribute_file_path, attribute_file_fd): # every subclass needs to have `self.deployer_kwargs` command = get_lower_level_group( self.deployer.api, self.deployer.top_level_kwargs, self.deployer.TYPE, self.deployer.deployer_kwargs, - ).trigger(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + ).trigger(deployer_attribute_file=attribute_file_path, **kwargs) pid = self.deployer.spm.run_command( [sys.executable, *command], @@ -218,7 +214,7 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun: command_obj = self.deployer.spm.get(pid) content = handle_timeout( - tfp_runner_attribute, command_obj, self.deployer.file_read_timeout + attribute_file_fd, command_obj, self.deployer.file_read_timeout ) if command_obj.process.returncode == 0: diff --git a/metaflow/runner/deployer_impl.py b/metaflow/runner/deployer_impl.py index 07e6cf51429..f9374ef8913 100644 --- a/metaflow/runner/deployer_impl.py +++ b/metaflow/runner/deployer_impl.py @@ -2,12 +2,11 @@ import json import os import sys -import tempfile from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type from .subprocess_manager import SubprocessManager -from .utils import get_lower_level_group, handle_timeout +from .utils import get_lower_level_group, handle_timeout, temporary_fifo if TYPE_CHECKING: import metaflow.runner.deployer @@ -121,14 +120,11 @@ def create(self, **kwargs) -> "metaflow.runner.deployer.DeployedFlow": def _create( self, create_class: Type["metaflow.runner.deployer.DeployedFlow"], **kwargs ) -> "metaflow.runner.deployer.DeployedFlow": - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) + with temporary_fifo() as (attribute_file_path, attribute_file_fd): # every subclass needs to have `self.deployer_kwargs` command = get_lower_level_group( self.api, self.top_level_kwargs, self.TYPE, self.deployer_kwargs - ).create(deployer_attribute_file=tfp_runner_attribute.name, **kwargs) + ).create(deployer_attribute_file=attribute_file_path, **kwargs) pid = self.spm.run_command( [sys.executable, *command], @@ -139,7 +135,7 @@ def _create( command_obj = self.spm.get(pid) content = handle_timeout( - tfp_runner_attribute, command_obj, self.file_read_timeout + attribute_file_fd, command_obj, self.file_read_timeout ) content = json.loads(content) self.name = content.get("name") diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 78418f49e8a..3a0d16552aa 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -2,13 +2,16 @@ import os import sys import json -import tempfile from typing import Dict, Iterator, Optional, Tuple from metaflow import Run -from .utils import handle_timeout +from .utils import ( + temporary_fifo, + handle_timeout, + async_handle_timeout, +) from .subprocess_manager import CommandManager, SubprocessManager @@ -267,9 +270,22 @@ def __enter__(self) -> "Runner": async def __aenter__(self) -> "Runner": return self - def __get_executing_run(self, tfp_runner_attribute, command_obj): - content = handle_timeout( - tfp_runner_attribute, command_obj, self.file_read_timeout + def __get_executing_run(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + content = json.loads(content) + pathspec = "%s/%s" % 
(content.get("flow_name"), content.get("run_id")) + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + run_object = Run( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingRun(self, command_obj, run_object) + + async def __async_get_executing_run(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout ) content = json.loads(content) pathspec = "%s/%s" % (content.get("flow_name"), content.get("run_id")) @@ -298,12 +314,9 @@ def run(self, **kwargs) -> ExecutingRun: ExecutingRun ExecutingRun containing the results of the run. """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) + with temporary_fifo() as (attribute_file_path, attribute_file_fd): command = self.api(**self.top_level_kwargs).run( - runner_attribute_file=tfp_runner_attribute.name, **kwargs + runner_attribute_file=attribute_file_path, **kwargs ) pid = self.spm.run_command( @@ -314,7 +327,7 @@ def run(self, **kwargs) -> ExecutingRun: ) command_obj = self.spm.get(pid) - return self.__get_executing_run(tfp_runner_attribute, command_obj) + return self.__get_executing_run(attribute_file_fd, command_obj) def resume(self, **kwargs): """ @@ -332,12 +345,9 @@ def resume(self, **kwargs): ExecutingRun ExecutingRun containing the results of the resumed run. """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) + with temporary_fifo() as (attribute_file_path, attribute_file_fd): command = self.api(**self.top_level_kwargs).resume( - runner_attribute_file=tfp_runner_attribute.name, **kwargs + runner_attribute_file=attribute_file_path, **kwargs ) pid = self.spm.run_command( @@ -348,7 +358,7 @@ def resume(self, **kwargs): ) command_obj = self.spm.get(pid) - return self.__get_executing_run(tfp_runner_attribute, command_obj) + return self.__get_executing_run(attribute_file_fd, command_obj) async def async_run(self, **kwargs) -> ExecutingRun: """ @@ -368,12 +378,9 @@ async def async_run(self, **kwargs) -> ExecutingRun: ExecutingRun ExecutingRun representing the run that was started. """ - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) + with temporary_fifo() as (attribute_file_path, attribute_file_fd): command = self.api(**self.top_level_kwargs).run( - runner_attribute_file=tfp_runner_attribute.name, **kwargs + runner_attribute_file=attribute_file_path, **kwargs ) pid = await self.spm.async_run_command( @@ -383,7 +390,7 @@ async def async_run(self, **kwargs) -> ExecutingRun: ) command_obj = self.spm.get(pid) - return self.__get_executing_run(tfp_runner_attribute, command_obj) + return await self.__async_get_executing_run(attribute_file_fd, command_obj) async def async_resume(self, **kwargs): """ @@ -403,12 +410,9 @@ async def async_resume(self, **kwargs): ExecutingRun ExecutingRun representing the resumed run that was started. 
""" - with tempfile.TemporaryDirectory() as temp_dir: - tfp_runner_attribute = tempfile.NamedTemporaryFile( - dir=temp_dir, delete=False - ) + with temporary_fifo() as (attribute_file_path, attribute_file_fd): command = self.api(**self.top_level_kwargs).resume( - runner_attribute_file=tfp_runner_attribute.name, **kwargs + runner_attribute_file=attribute_file_path, **kwargs ) pid = await self.spm.async_run_command( @@ -418,7 +422,7 @@ async def async_resume(self, **kwargs): ) command_obj = self.spm.get(pid) - return self.__get_executing_run(tfp_runner_attribute, command_obj) + return await self.__async_get_executing_run(attribute_file_fd, command_obj) def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup() diff --git a/metaflow/runner/subprocess_manager.py b/metaflow/runner/subprocess_manager.py index c8016244ea0..37e69ce06bf 100644 --- a/metaflow/runner/subprocess_manager.py +++ b/metaflow/runner/subprocess_manager.py @@ -9,26 +9,61 @@ import threading from typing import Callable, Dict, Iterator, List, Optional, Tuple +from .utils import check_process_exited -def kill_process_and_descendants(pid, termination_timeout): + +def kill_processes_and_descendants(pids: List[str], termination_timeout: float): # TODO: there's a race condition that new descendants might # spawn b/w the invocations of 'pkill' and 'kill'. # Needs to be fixed in future. try: - subprocess.check_call(["pkill", "-TERM", "-P", str(pid)]) - subprocess.check_call(["kill", "-TERM", str(pid)]) + subprocess.check_call(["pkill", "-TERM", "-P", *pids]) + subprocess.check_call(["kill", "-TERM", *pids]) except subprocess.CalledProcessError: pass time.sleep(termination_timeout) try: - subprocess.check_call(["pkill", "-KILL", "-P", str(pid)]) - subprocess.check_call(["kill", "-KILL", str(pid)]) + subprocess.check_call(["pkill", "-KILL", "-P", *pids]) + subprocess.check_call(["kill", "-KILL", *pids]) except subprocess.CalledProcessError: pass +async def async_kill_processes_and_descendants( + pids: List[str], termination_timeout: float +): + # TODO: there's a race condition that new descendants might + # spawn b/w the invocations of 'pkill' and 'kill'. + # Needs to be fixed in future. 
+ try: + sub_term = await asyncio.create_subprocess_exec("pkill", "-TERM", "-P", *pids) + await sub_term.wait() + except Exception: + pass + + try: + main_term = await asyncio.create_subprocess_exec("kill", "-TERM", *pids) + await main_term.wait() + except Exception: + pass + + await asyncio.sleep(termination_timeout) + + try: + sub_kill = await asyncio.create_subprocess_exec("pkill", "-KILL", "-P", *pids) + await sub_kill.wait() + except Exception: + pass + + try: + main_kill = await asyncio.create_subprocess_exec("kill", "-KILL", *pids) + await main_kill.wait() + except Exception: + pass + + class LogReadTimeoutError(Exception): """Exception raised when reading logs times out.""" @@ -46,14 +81,28 @@ def __init__(self): loop = asyncio.get_running_loop() loop.add_signal_handler( signal.SIGINT, - lambda: self._handle_sigint(signum=signal.SIGINT, frame=None), + lambda: asyncio.create_task(self._async_handle_sigint()), ) except RuntimeError: signal.signal(signal.SIGINT, self._handle_sigint) + async def _async_handle_sigint(self): + pids = [ + str(command.process.pid) + for command in self.commands.values() + if command.process and not check_process_exited(command) + ] + if pids: + await async_kill_processes_and_descendants(pids, termination_timeout=2) + def _handle_sigint(self, signum, frame): - for each_command in self.commands.values(): - each_command.kill(termination_timeout=2) + pids = [ + str(command.process.pid) + for command in self.commands.values() + if command.process and not check_process_exited(command) + ] + if pids: + kill_processes_and_descendants(pids, termination_timeout=2) async def __aenter__(self) -> "SubprocessManager": return self @@ -472,7 +521,7 @@ def kill(self, termination_timeout: float = 2): """ if self.process is not None: - kill_process_and_descendants(self.process.pid, termination_timeout) + kill_processes_and_descendants([str(self.process.pid)], termination_timeout) else: print("No process to kill.") diff --git a/metaflow/runner/utils.py b/metaflow/runner/utils.py index 0ef202a3a5a..ec0f1865d7c 100644 --- a/metaflow/runner/utils.py +++ b/metaflow/runner/utils.py @@ -2,9 +2,11 @@ import ast import time import asyncio - +import tempfile +import select +from contextlib import contextmanager from subprocess import CalledProcessError -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict, TYPE_CHECKING, ContextManager, Tuple if TYPE_CHECKING: import tempfile @@ -39,45 +41,194 @@ def format_flowfile(cell): return "\n".join(lines) -def check_process_status( +def check_process_exited( command_obj: "metaflow.runner.subprocess_manager.CommandManager", -): +) -> bool: if isinstance(command_obj.process, asyncio.subprocess.Process): return command_obj.process.returncode is not None else: return command_obj.process.poll() is not None -def read_from_file_when_ready( - file_path: str, +@contextmanager +def temporary_fifo() -> ContextManager[Tuple[str, int]]: + """ + Create and open the read side of a temporary FIFO in a non-blocking mode. + + Returns + ------- + str + Path to the temporary FIFO. + int + File descriptor of the temporary FIFO. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "fifo") + os.mkfifo(path) + # Blocks until the write side is opened unless in non-blocking mode + fd = os.open(path, os.O_RDONLY | os.O_NONBLOCK) + try: + yield path, fd + finally: + os.close(fd) + + +def read_from_fifo_when_ready( + fifo_fd: int, + command_obj: "metaflow.runner.subprocess_manager.CommandManager", + encoding: str = "utf-8", + timeout: int = 3600, +) -> str: + """ + Read the content from the FIFO file descriptor when it is ready. + + Parameters + ---------- + fifo_fd : int + File descriptor of the FIFO. + command_obj : CommandManager + Command manager object that handles the write side of the FIFO. + encoding : str, optional + Encoding to use while reading the file, by default "utf-8". + timeout : int, optional + Timeout for reading the file in milliseconds, by default 3600. + + Returns + ------- + str + Content read from the FIFO. + + Raises + ------ + TimeoutError + If no event occurs on the FIFO within the timeout. + CalledProcessError + If the process managed by `command_obj` has exited without writing any + content to the FIFO. + """ + content = bytearray() + + poll = select.poll() + poll.register(fifo_fd, select.POLLIN) + + while True: + poll_begin = time.time() + poll.poll(timeout) + timeout -= 1000 * (time.time() - poll_begin) + + if timeout <= 0: + raise TimeoutError("Timeout while waiting for the file content") + + try: + data = os.read(fifo_fd, 128) + while data: + content += data + data = os.read(fifo_fd, 128) + + # Read from a non-blocking closed FIFO returns an empty byte array + break + + except BlockingIOError: + # FIFO is open but no data is available yet + continue + + if not content and check_process_exited(command_obj): + raise CalledProcessError(command_obj.process.returncode, command_obj.command) + + return content.decode(encoding) + + +async def async_read_from_fifo_when_ready( + fifo_fd: int, + command_obj: "metaflow.runner.subprocess_manager.CommandManager", + encoding: str = "utf-8", + timeout: int = 3600, +) -> str: + """ + Read the content from the FIFO file descriptor when it is ready. + + Parameters + ---------- + fifo_fd : int + File descriptor of the FIFO. + command_obj : CommandManager + Command manager object that handles the write side of the FIFO. + encoding : str, optional + Encoding to use while reading the file, by default "utf-8". + timeout : int, optional + Timeout for reading the file in milliseconds, by default 3600. + + Returns + ------- + str + Content read from the FIFO. + + Raises + ------ + TimeoutError + If no event occurs on the FIFO within the timeout. + CalledProcessError + If the process managed by `command_obj` has exited without writing any + content to the FIFO. + """ + return await asyncio.to_thread( + read_from_fifo_when_ready, fifo_fd, command_obj, encoding, timeout + ) + + +def make_process_error_message( command_obj: "metaflow.runner.subprocess_manager.CommandManager", - timeout: float = 5, ): - start_time = time.time() - with open(file_path, "r", encoding="utf-8") as file_pointer: - content = file_pointer.read() - while not content: - if check_process_status(command_obj): - # Check to make sure the file hasn't been read yet to avoid a race - # where the file is written between the end of this while loop and the - # poll call above. 
- content = file_pointer.read() - if content: - break - raise CalledProcessError( - command_obj.process.returncode, command_obj.command - ) - if time.time() - start_time > timeout: - raise TimeoutError( - "Timeout while waiting for file content from '%s'" % file_path - ) - time.sleep(0.1) - content = file_pointer.read() - return content + stdout_log = open(command_obj.log_files["stdout"], encoding="utf-8").read() + stderr_log = open(command_obj.log_files["stderr"], encoding="utf-8").read() + command = " ".join(command_obj.command) + error_message = "Error executing: '%s':\n" % command + if stdout_log.strip(): + error_message += "\nStdout:\n%s\n" % stdout_log + if stderr_log.strip(): + error_message += "\nStderr:\n%s\n" % stderr_log + return error_message def handle_timeout( - tfp_runner_attribute: "tempfile._TemporaryFileWrapper[str]", + attribute_file_fd: int, + command_obj: "metaflow.runner.subprocess_manager.CommandManager", + file_read_timeout: int, +): + """ + Handle the timeout for a running subprocess command that reads a file + and raises an error with appropriate logs if a TimeoutError occurs. + + Parameters + ---------- + attribute_file_fd : int + File descriptor belonging to the FIFO containing the attribute data. + command_obj : CommandManager + Command manager object that encapsulates the running command details. + file_read_timeout : int + Timeout for reading the file. + + Returns + ------- + str + Content read from the temporary file. + + Raises + ------ + RuntimeError + If a TimeoutError occurs, it raises a RuntimeError with the command's + stdout and stderr logs. + """ + try: + return read_from_fifo_when_ready( + attribute_file_fd, command_obj=command_obj, timeout=file_read_timeout + ) + except (CalledProcessError, TimeoutError) as e: + raise RuntimeError(make_process_error_message(command_obj)) from e + + +async def async_handle_timeout( + attribute_file_fd: "int", command_obj: "metaflow.runner.subprocess_manager.CommandManager", file_read_timeout: int, ): @@ -87,8 +238,8 @@ def handle_timeout( Parameters ---------- - tfp_runner_attribute : NamedTemporaryFile - Temporary file that stores runner attribute data. + attribute_file_fd : int + File descriptor belonging to the FIFO containing the attribute data. command_obj : CommandManager Command manager object that encapsulates the running command details. file_read_timeout : int @@ -106,20 +257,11 @@ def handle_timeout( stdout and stderr logs. """ try: - content = read_from_file_when_ready( - tfp_runner_attribute.name, command_obj, timeout=file_read_timeout + return await async_read_from_fifo_when_ready( + attribute_file_fd, command_obj=command_obj, timeout=file_read_timeout ) - return content except (CalledProcessError, TimeoutError) as e: - stdout_log = open(command_obj.log_files["stdout"], encoding="utf-8").read() - stderr_log = open(command_obj.log_files["stderr"], encoding="utf-8").read() - command = " ".join(command_obj.command) - error_message = "Error executing: '%s':\n" % command - if stdout_log.strip(): - error_message += "\nStdout:\n%s\n" % stdout_log - if stderr_log.strip(): - error_message += "\nStderr:\n%s\n" % stderr_log - raise RuntimeError(error_message) from e + raise RuntimeError(make_process_error_message(command_obj)) from e def get_lower_level_group(