From 8180ea6b0e0f5896de0adef098da568e96d8ef67 Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Fri, 29 Jul 2022 06:38:52 +0000 Subject: [PATCH] Full Airflow integration with foreach and Sensors Squashed commit of the following: commit 43fadd01598dd41f66ad57dad1978181a449ba0f Author: Valay Dave Date: Fri Jul 29 00:59:01 2022 +0000 committing k8s-related file from master. commit 228b60cef41c80df310ebf4d7b34f3d3241b9c22 Author: Valay Dave Date: Fri Jul 29 00:54:45 2022 +0000 Uncommented code for foreach support with k8s. KubernetesPodOperator version 4.2.0 renamed `resources` to `container_resources` - Check: (https://github.com/apache/airflow/pull/24673) / (https://github.com/apache/airflow/commit/45f4290712f5f779e57034f81dbaab5d77d5de85). This was done because `KubernetesPodOperator` didn't play nicely with dynamic task mapping, so the `resources` argument had to be deprecated. Hence the codepath below checks the version of `KubernetesPodOperator` and sets the argument accordingly: if the version is < 4.2.0 we set the argument as `resources`; if it is >= 4.2.0 we set it as `container_resources`. The `resources` argument of `KubernetesPodOperator` is going to be deprecated in the future, so we only use it for `KubernetesPodOperator` versions < 4.2.0. The `resources` argument also does not work for foreaches. commit 8e7ac882282511674c5272102a0d529a37808293 Author: Valay Dave Date: Mon Jul 18 18:31:58 2022 +0000 nit fixes: - fixing comments. - refactoring some variable/function names. commit effca46b6837b1149feb73e37e2816a1f9f088d4 Author: Valay Dave Date: Mon Jul 18 18:14:53 2022 +0000 change `token` to `production_token` commit fd50ddc37b9d3694bd57e89e10a44c4a350d479d Author: Valay Dave Date: Mon Jul 18 18:11:56 2022 +0000 Refactored import of Airflow Sensors. commit bd86fa00a030984c96577e92ae24595f6859dec1 Author: Valay Dave Date: Mon Jul 18 18:08:41 2022 +0000 new comment on `startup_timeout_seconds` env var. commit 1afd191a92b8ccdb22aef7853a5eb72e05d1a834 Author: Valay Dave Date: Mon Jul 18 18:06:09 2022 +0000 Removing traces of `@airflow_schedule_interval` commit 374b6330943b149836156e0aaed8de2935ceca7f Author: Valay Dave Date: Thu Jul 14 12:43:08 2022 -0700 Foreach polish (valayDave/metaflow#62) * Removing unused imports * Added validation logic for airflow version numbers with foreaches * Removed `airflow_schedule_interval` decorator. * Added production/deployment token related changes - Uses s3 as a backend to store the production token - Token used for avoiding name clashes - token stored via `FlowDatastore` * Graph type validation for airflow foreaches - Airflow foreaches only support single-node fanout. - validation invalidates graphs with nested foreaches * Added configuration for startup_timeout. * Added final todo on the `resources` argument of the k8s op - added a commented code block - it needs to be uncommented when airflow releases the patch for the op - Code seems feature complete, keeping aside the airflow patch commit 8d27ee769a066c72734f8e95f308bed4ba69978e Author: Valay Dave Date: Thu Jul 7 19:33:07 2022 +0000 Removed retries from user-defaults. commit cfa5a1508726bde402900fc554185811bd9100b6 Author: Valay Dave Date: Wed Jul 6 16:29:33 2022 +0000 updated pod startup time commit de0004224e2532c539b69b8c1fe64eea63dad56b Author: Valay Dave Date: Wed Jun 29 18:44:11 2022 +0000 Adding a default of 1 retry for any airflow worker.
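For illustration, the `resources`/`container_resources` version gate described in commit 228b60ce boils down to something like the sketch below. This is a minimal sketch with made-up helper names and simplified version parsing, not the code shipped in this patch:

    # Pick the KubernetesPodOperator keyword for resource requests based on the
    # installed cncf-kubernetes provider version: 4.2.0 renamed `resources` to
    # `container_resources`, and only the new name works with dynamic task mapping.
    def resource_arg_name(provider_version):
        major, minor = (int(x) for x in provider_version.split(".")[:2])
        return "container_resources" if (major, minor) >= (4, 2) else "resources"

    def pod_operator_kwargs(provider_version, resource_spec, **kwargs):
        kwargs[resource_arg_name(provider_version)] = resource_spec
        return kwargs

    # pod_operator_kwargs("4.2.0", {"requests": {"cpu": "1"}}, name="start")
    # -> {"name": "start", "container_resources": {"requests": {"cpu": "1"}}}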
commit 63d7c12e635472c525eee1e5b5edc001dd8dc221 Author: Valay Dave Date: Mon Jun 27 01:22:42 2022 +0000 Airflow Foreach Integration - Simple one-node foreach-join support as guaranteed by airflow - Fixed env variable setting issue - introduced MetaflowKubernetesOperator - Created a new operator to allow smoothness in plumbing xcom values - Some todos commit c3968839e23d82eb47728ddda55e43081fe91926 Author: Valay Dave Date: Fri Jun 24 21:12:09 2022 +0000 simplifying run-id macro. commit 396bdadd536c4e2bdcb708fc747109e00ffd89a6 Author: Valay Dave Date: Fri Jun 24 11:51:42 2022 -0700 Refactored parameter macro settings. (valayDave/metaflow#60) commit 0506d873449e132bda1670be06eaacc954d79128 Author: Valay Dave Date: Fri Jun 24 02:05:57 2022 +0000 added comment on the need for `start_date` commit 726a5c26abef6d5ce3af4270ebad19bdcaed3440 Author: Valay Dave Date: Tue Jun 21 06:03:56 2022 +0000 Refactored an `id_creator` method. commit 7ef186f56cdedbd2d7d07cc07cd907038a5584f7 Author: Valay Dave Date: Tue Jun 21 05:52:05 2022 +0000 refactor: - `RUN_ID_LEN` to `RUN_HASH_ID_LEN` - `TASK_ID_LEN` to `TASK_ID_HASH_LEN` commit 04299e3eb46e71f2bcb59ed4916742f3fd07b3ee Author: Valay Dave Date: Tue Jun 21 05:48:55 2022 +0000 refactored an error string commit ad3dc1cfd7bbcf0ba1c299587a5a572898efa016 Author: Valay Dave Date: Mon Jun 20 22:42:36 2022 -0700 addressing Savin's comments. (#59) - Added many ad-hoc changes based on comments. - Integrated secrets and `KUBERNETES_SECRETS` - cleaned up parameter setting - cleaned up setting of scheduling interval - renamed `AIRFLOW_TASK_ID_TEMPLATE_VALUE` to `AIRFLOW_TASK_ID` - renamed `AirflowSensorDecorator.compile` to `AirflowSensorDecorator.validate` - Checking whether the DAG file and the flow file are the same. - fixing variable names. - checking out `kubernetes_decorator.py` from master (6441ed5) - bug fixing secret setting in airflow. - simplified parameter type parsing logic - refactoring airflow argument parsing code. commit 8e68090c72899d9e46b3e9426dd675f088297fc1 Author: Valay Dave Date: Mon Jun 13 14:02:57 2022 -0700 Addressing final comments. (#57) - Added dag-run timeout. - airflow-related scheduling checks in decorator. - Auto-naming sensors if no name is provided - Annotations to k8s operators - fix: argument serialization for `DAG` arguments (method names refactored, e.g. `to_dict` became `serialize`) - annotation bug fix - setting `workflow-timeout` only for scheduled dags commit d3ad82d93009003ce88a8fa91d3d86c0457ef5d6 Author: Valay Dave Date: Mon Jun 6 04:50:49 2022 +0000 k8s bug fix commit 483ffa80451c42e03a9f74674dbd93119e1fc9f3 Author: Valay Dave Date: Mon Jun 6 04:39:50 2022 +0000 removed un-used function commit 3782961c7fa44c812dc4cf7ac0e0217e8856ff31 Author: Valay Dave Date: Mon Jun 6 04:38:37 2022 +0000 Removed unused `sanitize_label` function commit 85d33910a3e2db90f43b1668b62246427f54bb02 Author: Valay Dave Date: Mon Jun 6 04:37:34 2022 +0000 GPU support added + container naming same as argo commit 9745d6fc388a9b77c2b77cf3e97d01e017d12aa8 Author: Valay Dave Date: Mon Jun 6 04:25:17 2022 +0000 Refactored sensors into different files + bug fix - bug caused by `util.compress_list`. - The function doesn't play nicely with strings containing a variety of characters. - Ensured that exceptions are handled appropriately. - Made a new file for each sensor under the `airflow.sensors` module. commit bde3a0e2a77161d3c731de660314cf5454377cbd Author: Valay Dave Date: Sat Jun 4 01:41:49 2022 +0000 ran black.
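The `RUN_HASH_ID_LEN`/`TASK_ID_HASH_LEN` refactor above refers to the hash-based ids used to keep run ids and task ids short and collision-free across concurrent dag-runs. Roughly, as a simplified sketch — the real scheme lives in `airflow_utils.py` and may differ in hash choice and formatting:

    import hashlib

    RUN_HASH_ID_LEN = 12
    TASK_ID_HASH_LEN = 8
    RUN_ID_PREFIX = "airflow"

    def _hash(value, length):
        return hashlib.md5(value.encode("utf-8")).hexdigest()[:length]

    # Run ids are derived from the (unique) Airflow dag-run id; task ids are
    # derived from "run-id/step-name", so parallel dag-runs never collide.
    def run_id(dag_run_id):
        return "%s-%s" % (RUN_ID_PREFIX, _hash(dag_run_id, RUN_HASH_ID_LEN))

    def task_id(run_id, step_name):
        return _hash("%s/%s" % (run_id, step_name), TASK_ID_HASH_LEN)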
commit 360d2c5eefd86bd71439ae3dd199cc1ced126989 Author: Valay Dave Date: Fri Jun 3 18:32:48 2022 -0700 Moving information from airflow_utils to compiler (#56) - commenting todos to organize unfinished changes. - some environment variables set via`V1EnvVar` - `client.V1ObjectFieldSelector` mapped env vars were not working in json form - Moving k8s operator import into its own function. - env vars moved. commit 36f67a9636013199c31fc892ac7e10a6b809a282 Author: Valay Dave Date: Fri Jun 3 18:06:03 2022 +0000 added mising Run-id prefixes to variables. - merged `hash` and `dash_connect` filters. commit 5f4d5733556b9f93a8e2d67a2646b53c8ecbdb59 Author: Valay Dave Date: Fri Jun 3 18:00:22 2022 +0000 nit fix : variable name change. commit 4aee4ca8350fd18857c1d345f9e37f2ab5bc5a3f Author: Valay Dave Date: Fri Jun 3 17:58:34 2022 +0000 nit fixes to dag.py's templating variables. commit 7e6da56cf3d09cd5a3b8b196eadadf4f2ae54241 Author: Valay Dave Date: Fri Jun 3 17:56:53 2022 +0000 Fixed defaults passing - Addressed comments for airflow.py commit e9141821fbed0aafbd756a08b24628c8d6bd7e1e Author: Valay Dave Date: Fri Jun 3 17:52:24 2022 +0000 Following Changes: - Refactors setting scheduling interval - refactor dag file creating function - refactored is_active to is_paused_upon_creation - removed catchup commit 3998d4bba9a0cd573a57669a150fbc39a71de75a Author: Valay Dave Date: Fri Jun 3 17:33:25 2022 +0000 Multiple Changes based on comments: 1. refactored `create_k8s_args` into _to_job 2. Addressed comments for snake casing 3. refactored `attrs` for simplicity. 4. refactored `metaflow_parameters` to `parameters`. 5. Refactored setting of `input_paths` commit fa21dc219bee44a4036384499315fbd60f99272f Author: Valay Dave Date: Fri Jun 3 16:42:24 2022 +0000 Removed Sensor metadata extraction. commit 284608f77f26ede17fec71e39d8e3e2418914505 Author: Valay Dave Date: Fri Jun 3 16:30:34 2022 +0000 porting savin's comments - next changes : addressing comments. commit 41c2118afff4ee3f302b9eba451e407a82aabb18 Merge: 3f2353a ca29bec Author: Valay Dave Date: Fri Jul 29 06:36:27 2022 +0000 Merge branch 'master' into airflow commit 3f2353a647e53bc240e28792769c42a71ea8f8c9 Merge: d370ffb c1ff469 Author: Valay Dave Date: Thu Jul 28 23:52:16 2022 +0000 Merge branch 'master' into airflow commit d370ffb248411ad4675f9d55de709dbd75d3806e Merge: a82f144 e4eb751 Author: Valay Dave Date: Thu Jul 14 19:38:48 2022 +0000 Merge branch 'master' into airflow commit a82f1447b414171fc5611758cb6c12fc692f55f9 Merge: bdb1f0d 6f097e3 Author: Valay Dave Date: Wed Jul 13 00:35:49 2022 +0000 Merge branch 'master' into airflow commit bdb1f0dd248d01318d4a493c75b6f54248c7be64 Merge: 8511215 f9a4968 Author: Valay Dave Date: Wed Jun 29 18:44:51 2022 +0000 Merge branch 'master' into airflow commit 85112158cd352cb7de95a2262c011c6f43d98283 Author: Valay Dave Date: Tue Jun 21 02:53:11 2022 +0000 Bug fix from master merge. commit 90c06f12bb14eda51c6a641766c5f67d6763abaa Merge: 0fb73af 6441ed5 Author: Valay Dave Date: Mon Jun 20 21:20:20 2022 +0000 Merge branch 'master' into airflow commit 0fb73af8af9fca2875261e3bdd305a0daab1b229 Author: Valay Dave Date: Sat Jun 4 00:53:10 2022 +0000 squashing bugs after changes from master. 
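The `V1EnvVar` note above (field-selector env vars not working in JSON form) is essentially the pattern below: plain name/value pairs can travel as serialized dicts, but downward-API style variables have to be rebuilt as Kubernetes client objects before the operator sees them. Sketch only; the helper name and call site are assumptions, not this patch's code:

    from kubernetes import client

    def build_env_vars(plain_env, field_refs=None):
        # Plain name/value pairs survive JSON serialization as-is.
        env = [client.V1EnvVar(name=k, value=v) for k, v in plain_env.items()]
        # Downward-API values (e.g. the pod namespace) need real client objects.
        for name, field_path in (field_refs or {}).items():
            env.append(
                client.V1EnvVar(
                    name=name,
                    value_from=client.V1EnvVarSource(
                        field_ref=client.V1ObjectFieldSelector(field_path=field_path)
                    ),
                )
            )
        return env

    # build_env_vars({"METAFLOW_RUN_ID": "airflow-abc123"},
    #                {"METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace"})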
commit 09c6ba779f6b1b6ef1d7ed5b1bb2be70ec76575d Merge: 7bdf662 ffff49b Author: Valay Dave Date: Sat Jun 4 00:20:38 2022 +0000 Merge branch 'master' into af-mmr commit 7bdf662e14966b929b8369c65d5bd3bbe5741937 Author: Valay Dave Date: Mon May 16 17:42:38 2022 -0700 Airflow sensor api (#3) * Fixed run-id setting - Change gaurentees that multiple dags triggered at same moment have unique run-id * added allow multiple in `Decorator` class * Airflow sensor integration. >> support added for : - ExternalTaskSensor - S3KeySensor - SqlSensor >> sensors allow multiple decorators >> sensors accept those arguments which are supported by airflow * Added `@airflow_schedule_interval` decorator * Fixing bug run-id related in env variable setting. commit 2604a29452e794354cf4c612f48bae7cf45856ee Author: Valay Dave Date: Thu Apr 21 18:26:59 2022 +0000 Addressed comments. commit 584e88b679fed7d6eec8ce564bf3707359170568 Author: Valay Dave Date: Wed Apr 20 03:33:55 2022 +0000 fixed printing bug commit 169ac1535e5567149d94749ddaf70264e882d62c Author: Valay Dave Date: Wed Apr 20 03:30:59 2022 +0000 Option help bug fix. commit 6f8489bcc3bd715b65d8a8554a0f3932dc78c6f5 Author: Valay Dave Date: Wed Apr 20 03:25:54 2022 +0000 variable renamemetaflow_specific_args commit 0c779abcd1d9574878da6de8183461b53e0da366 Merge: d299b13 5a61508 Author: Valay Dave Date: Wed Apr 20 03:23:10 2022 +0000 Merge branch 'airflow-tests' into airflow commit 5a61508e61583b567ef8d3fea04e049d74a6d973 Author: Valay Dave Date: Wed Apr 20 03:22:54 2022 +0000 Removing un-used code / resolved-todos. commit d030830f2543f489a1c4ebd17da1b47942f041d6 Author: Valay Dave Date: Wed Apr 20 03:06:03 2022 +0000 ran black, commit 2d1fc06e41cbe45ccfd46e03bc87b09c7a78da45 Merge: f2cb319 7921d13 Author: Valay Dave Date: Wed Apr 20 03:04:19 2022 +0000 Merge branch 'master' into airflow-tests commit d299b13ce38d027ab27ce23c9bbcc0f43b222cfa Merge: f2cb319 7921d13 Author: Valay Dave Date: Wed Apr 20 03:02:37 2022 +0000 Merge branch 'master' into airflow commit f2cb3197725f11520da0d49cbeef8de215c243eb Author: Valay Dave Date: Wed Apr 20 02:54:03 2022 +0000 reverting change. commit 05b9db9cf0fe8b40873b2b74e203b4fc82e7fea4 Author: Valay Dave Date: Wed Apr 20 02:47:41 2022 +0000 3 changes: - Removing s3 dep - remove uesless import - added `deployed_on` in dag file template commit c6afba95f5ec05acf7f33fd3228cffd784556e3b Author: Valay Dave Date: Fri Apr 15 22:50:52 2022 +0000 Fixed passing secrets with kubernetes. commit c3ce7e9faa5f7a23d309e2f66f778dbca85df22a Author: Valay Dave Date: Fri Apr 15 22:04:22 2022 +0000 Refactored code . - removed compute/k8s.py - Moved k8s code to airflow_compiler.py - ran isort to airflow_compiler.py commit d1c343dbbffbddbebd2aeda26d6846e595144e0b Author: Valay Dave Date: Fri Apr 15 18:02:25 2022 +0000 Added validations about: - un-supported decorators - foreach Changed where validations are done to not save the package. commit 7b19f8e66e278c75d836daf6a1c7ed2c607417ce Author: Valay Dave Date: Fri Apr 15 03:34:26 2022 +0000 Fixing mf log related bug - No double logging on metaflow. commit 4d1f6bf9bb32868c949d8c103c8fe44ea41b3f13 Author: Valay Dave Date: Thu Apr 14 03:10:51 2022 +0000 Removed usless code WRT project decorator. commit 5ad9a3949e351b0ac13f11df13446953932e8ffc Author: Valay Dave Date: Thu Apr 14 03:03:19 2022 +0000 Remove readme. commit 60cb6a79404efe2bcf9bf9a118a68f0b98c7d771 Author: Valay Dave Date: Thu Apr 14 03:02:38 2022 +0000 Made file path required arguement. 
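As a usage sketch of the sensor integration above: the decorator names follow the sensor modules added in this patch, but treat the exact importable names and keyword arguments as assumptions — the sensors simply forward whatever arguments the underlying Airflow sensors accept.

    from metaflow import FlowSpec, step
    from metaflow import airflow_external_task_sensor, airflow_s3_key_sensor

    # Several sensor decorators may be stacked thanks to `allow_multiple`;
    # each compiles to an Airflow sensor that gates the generated DAG.
    @airflow_external_task_sensor(external_dag_id="upstream_dag")
    @airflow_s3_key_sensor(bucket_key="s3://my-bucket/input/ready.marker")
    class DownstreamFlow(FlowSpec):
        @step
        def start(self):
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        DownstreamFlow()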
commit 9f0dc1b2e01ee04b05620630f3a0ec04fe873a31 Author: Valay Dave Date: Thu Apr 14 03:01:07 2022 +0000 changed `--is-active`->`--is-paused-upon-creation` - dags are active by default. commit 5b98f937a62ee74de8aed8b0efde5045a28f068b Author: Valay Dave Date: Thu Apr 14 02:55:46 2022 +0000 shortened length of run-id and task-id hashes. commit e53426eaa4b156e8bd70ae7510c2e7c66745d101 Author: Valay Dave Date: Thu Apr 14 02:41:32 2022 +0000 Removing un-used args. commit 72cbbfc7424f9be415c22d9144b16a0953f15295 Author: Valay Dave Date: Thu Apr 14 02:39:59 2022 +0000 Moved exceptions to airflow compiler commit b2970ddaa86c393c8abb7f203f6507c386ecbe00 Author: Valay Dave Date: Thu Apr 14 02:33:02 2022 +0000 Changes based on PR comments: - removed airflow xcom push file , moved to decorator code - removed prefix configuration - nit fixes. commit 9e622bac5a75eb9e7a6594d8fa0e47f076634b44 Author: Valay Dave Date: Mon Apr 11 20:39:00 2022 +0000 Removing un-used code paths + code cleanup commit 7425f62cff2c9128eea785223ddeb40fa2d8f503 Author: Valay Dave Date: Mon Apr 11 19:45:04 2022 +0000 Fixing bug fix in schedule. commit eb775cbadd1d2d2c90f160a95a0f42c8ff0d7f4c Author: Valay Dave Date: Sun Apr 10 02:52:59 2022 +0000 Bug fixes WRT Kubernetes secrets + k8s deployments. - Fixing some error messages. - Added some comments. commit 04c92b92c312a4789d3c1e156f61ef57b08dba9f Author: Valay Dave Date: Sun Apr 10 01:20:53 2022 +0000 Added secrets support. commit 4a0a85dff77327640233767e567aee2b379ac13e Author: Valay Dave Date: Sun Apr 10 00:11:46 2022 +0000 Bug fix. commit af91099c0a30c26b58d58696a3ef697ec49a8503 Author: Valay Dave Date: Sun Apr 10 00:03:34 2022 +0000 bug fix. commit c17f04a253dfe6118e2779db79da9669aa2fcef2 Author: Valay Dave Date: Sat Apr 9 23:55:41 2022 +0000 Bug fix in active defaults. commit 0d372361297857076df6af235d1de7005ac1544a Author: Valay Dave Date: Sat Apr 9 23:54:02 2022 +0000 @project, @schedule, default active dag support. - Added a flag to allow setting dag as active on creation - Airflow compatible schedule interval - Project name fixes. commit 5c97b15cb11b5e8279befc5b14c239463750e9b7 Author: Valay Dave Date: Thu Apr 7 21:15:18 2022 +0000 Max workers and worker pool support. commit 9c973f2f44c3cb3a98e3e63f6e4dcef898bc8bf2 Author: Valay Dave Date: Thu Apr 7 19:34:33 2022 +0000 Adding exceptions for missing features. commit 2a946e2f083a34b4b6ed84c70aebf96b084ee8a2 Author: Valay Dave Date: Mon Mar 28 19:34:11 2022 +0000 2 changes : - removed hacky line - added support to directly throw dags in s3. commit e0772ec1bad473482c6fd19f8c5e8b9845303c0a Author: Valay Dave Date: Wed Mar 23 22:38:20 2022 +0000 fixing bugs in service account setting commit 874b94aeeabc664f12551864eff9d8fdc24dc37b Author: Valay Dave Date: Sun Mar 20 23:49:15 2022 +0000 Added support for Branching with Airflow - remove `next` function in `AirflowTask` - `AirflowTask`s has no knowledge of next tasks. - removed todos + added some todos - Graph construction on airflow side using graph_structure datastructure. - graph_structure comes from`FlowGraph.output_steps()[1]` commit 8e9f649bd8c51171c38a1e5af70a44a85e7009ca Author: Valay Dave Date: Sun Mar 20 02:33:04 2022 +0000 Added hacky line commit fd5db04cf0a81b14efda5eaf40cd9227e2bac0d3 Author: Valay Dave Date: Sun Mar 20 02:06:38 2022 +0000 Removed hacky line. commit 5b23eb7d8446bef71246d853b11edafa93c6ef95 Author: Valay Dave Date: Sun Mar 20 01:44:57 2022 +0000 Added support for Parameters. 
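The secrets support mentioned above ultimately hands `Secret` objects to the pod operator; the idea is roughly the sketch below. The import path moves between Airflow/provider versions and the helper name is an assumption:

    from airflow.kubernetes.secret import Secret  # path differs across provider versions

    def to_env_secrets(secret_names):
        # Expose every key of each named Kubernetes secret as environment
        # variables inside the task pod (deploy_type="env", deploy_target=None).
        return [Secret("env", None, name) for name in secret_names if name]

    # @kubernetes(secrets="db-creds,api-keys") plus METAFLOW_KUBERNETES_SECRETS
    # -> to_env_secrets(["db-creds", "api-keys", ...])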
- Supporting int, str, bool, float, JSONType commit c9378e9b284657357ad2997f2b492bc2f4aaefac Author: Valay Dave Date: Sun Mar 20 00:14:10 2022 +0000 Removed todos + added some validation logic. commit 7250a44e1dea1da3464f6f71d0c5188bd314275a Author: Valay Dave Date: Sat Mar 19 23:45:15 2022 +0000 Fixing logs related change from master. commit d125978619ab666dcf96db330acdca40f41b7114 Merge: 8cdac53 7e210a2 Author: Valay Dave Date: Sat Mar 19 23:42:48 2022 +0000 Merge branch 'master' into aft-mm commit 8cdac53dd32648455e36955badb8e0ef7b95a2b3 Author: Valay Dave Date: Sat Mar 19 23:36:47 2022 +0000 making changes sync with master commit 5a93d9f5198c360b2a84ab13a86496986850953c Author: Valay Dave Date: Sat Mar 19 23:29:47 2022 +0000 Fixed bug when using catch + retry commit 62bc8dff68a6171b3b4222075a8e8ac109f65b4c Author: Valay Dave Date: Sat Mar 19 22:58:37 2022 +0000 Changed retry setting. commit 563a20036a2dfcc48101f680f29d4917d53aa247 Author: Valay Dave Date: Sat Mar 19 22:42:57 2022 +0000 Fixed setting `task_id` : - switch task-id from airflow job is to hash to "runid/stepname" - refactor xcom setting variables - added comments commit e2a1e502221dc603385263c82e2c068b9f055188 Author: Valay Dave Date: Sat Mar 19 17:51:59 2022 +0000 setting retry logic. commit a697b56052210c8f009b68772c902bbf77713202 Author: Valay Dave Date: Thu Mar 17 01:02:11 2022 +0000 Nit fix. commit 68f13beb17c7e73c0dddc142ef2418675a506439 Author: Valay Dave Date: Wed Mar 16 20:46:19 2022 +0000 Added @schedule support + readme commit 57bdde54f9ad2c8fe5513dbdb9fd02394664e234 Author: Valay Dave Date: Tue Mar 15 19:47:06 2022 +0000 Fixed setting run-id / task-id to labels in k8s - Fixed setting run-id has from cli macro - added hashing macro to ensure that jinja template set the correct run-id to k8s labels - commit 3d6c31917297d0be5f9915b13680fc415ddb4421 Author: Valay Dave Date: Tue Mar 15 05:39:04 2022 +0000 Got linear workflows working on airflow. - Still not feature complete as lots of args are still unfilled / lots of unknows - minor tweek in eks to ensure airflow is k8s compatible. - passing state around via xcom-push - HACK : AWS keys are passed in a shady way. : Reverse this soon. commit db074b8012f76d9d85225a4ceddb2cde8fefa0f4 Author: Valay Dave Date: Fri Mar 11 12:34:33 2022 -0800 Tweeks commit a9f0468c4721a2017f1b26eb8edcdd80aaa57203 Author: Valay Dave Date: Tue Mar 1 17:14:47 2022 -0800 some changes based on savin's comments. - Added changes to task datastore for different reason : (todo) Decouple these - Added comments to SFN for reference. - Airflow DAG is no longer dependent on metaflow commit f32d089cd3865927bc7510f24ba3418d859410b6 Author: Valay Dave Date: Wed Feb 23 00:54:17 2022 -0800 First version of dynamic dag compiler. - Not completely finished code - Creates generic .py file a JSON that is parsed to create Airflow DAG. - Currently only boiler plate to make a linear dag but doesn't execute anything. - Unfinished code. commit d2def665a86d6a6622d6076882c1c2d54044e773 Author: Valay Dave Date: Sat Feb 19 14:01:47 2022 -0800 more tweeks. commit b176311f166788cc3dfc93354a0c5045a4e6a3d4 Author: Valay Dave Date: Thu Feb 17 09:04:29 2022 -0800 commit 0 - unfinished code. 
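The xcom-based state passing mentioned above works roughly as sketched below. Illustrative only: the real templates are built by `_make_input_path` with the `TASK_ID_XCOM_KEY` key, and the run-id macro is more involved than the bare `run_id` shown here.

    # 1) Inside the pod, the finishing task writes its Metaflow task id where the
    #    KubernetesPodOperator xcom sidecar picks it up (do_xcom_push=True):
    #       /airflow/xcom/return.json  ->  {"metaflow_task_id": "<task-id>"}
    #
    # 2) The downstream operator's command is a Jinja template that pulls the id
    #    back at render time and splices it into --input-paths:
    input_path_template = (
        "{{ run_id }}/start/"
        "{{ task_instance.xcom_pull(task_ids='start', key='metaflow_task_id') }}"
    )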
--- metaflow/decorators.py | 41 +- metaflow/metaflow_config.py | 9 + metaflow/plugins/__init__.py | 13 +- metaflow/plugins/airflow/__init__.py | 0 metaflow/plugins/airflow/airflow.py | 705 ++++++++++++++++++ metaflow/plugins/airflow/airflow_cli.py | 404 ++++++++++ metaflow/plugins/airflow/airflow_decorator.py | 66 ++ metaflow/plugins/airflow/airflow_utils.py | 646 ++++++++++++++++ metaflow/plugins/airflow/dag.py | 9 + metaflow/plugins/airflow/exception.py | 12 + .../airflow/plumbing/set_parameters.py | 21 + metaflow/plugins/airflow/sensors/__init__.py | 9 + .../plugins/airflow/sensors/base_sensor.py | 74 ++ .../airflow/sensors/external_task_sensor.py | 94 +++ metaflow/plugins/airflow/sensors/s3_sensor.py | 26 + .../plugins/airflow/sensors/sql_sensor.py | 31 + 16 files changed, 2148 insertions(+), 12 deletions(-) create mode 100644 metaflow/plugins/airflow/__init__.py create mode 100644 metaflow/plugins/airflow/airflow.py create mode 100644 metaflow/plugins/airflow/airflow_cli.py create mode 100644 metaflow/plugins/airflow/airflow_decorator.py create mode 100644 metaflow/plugins/airflow/airflow_utils.py create mode 100644 metaflow/plugins/airflow/dag.py create mode 100644 metaflow/plugins/airflow/exception.py create mode 100644 metaflow/plugins/airflow/plumbing/set_parameters.py create mode 100644 metaflow/plugins/airflow/sensors/__init__.py create mode 100644 metaflow/plugins/airflow/sensors/base_sensor.py create mode 100644 metaflow/plugins/airflow/sensors/external_task_sensor.py create mode 100644 metaflow/plugins/airflow/sensors/s3_sensor.py create mode 100644 metaflow/plugins/airflow/sensors/sql_sensor.py diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 0042e71dcaa..9ab5d5d4aee 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -100,6 +100,8 @@ class Decorator(object): name = "NONAME" defaults = {} + # `allow_multiple` allows setting many decorators of the same type to a step/flow. + allow_multiple = False def __init__(self, attributes=None, statically_defined=False): self.attributes = self.defaults.copy() @@ -226,9 +228,6 @@ class MyDecorator(StepDecorator): pass them around with every lifecycle call. """ - # `allow_multiple` allows setting many decorators of the same type to a step. - allow_multiple = False - def step_init( self, flow, graph, step_name, decorators, environment, flow_datastore, logger ): @@ -374,12 +373,17 @@ def _base_flow_decorator(decofunc, *args, **kwargs): if isinstance(cls, type) and issubclass(cls, FlowSpec): # flow decorators add attributes in the class dictionary, # _flow_decorators. 
- if decofunc.name in cls._flow_decorators: + if decofunc.name in cls._flow_decorators and not decofunc.allow_multiple: raise DuplicateFlowDecoratorException(decofunc.name) else: - cls._flow_decorators[decofunc.name] = decofunc( - attributes=kwargs, statically_defined=True - ) + deco_instance = decofunc(attributes=kwargs, statically_defined=True) + if decofunc.allow_multiple: + if decofunc.name not in cls._flow_decorators: + cls._flow_decorators[decofunc.name] = [deco_instance] + else: + cls._flow_decorators[decofunc.name].append(deco_instance) + else: + cls._flow_decorators[decofunc.name] = deco_instance else: raise BadFlowDecoratorException(decofunc.name) return cls @@ -469,11 +473,26 @@ def _attach_decorators_to_step(step, decospecs): def _init_flow_decorators( flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options ): + # Certain decorators can be specified multiple times and exist as lists in the _flow_decorators dictionary for deco in flow._flow_decorators.values(): - opts = {option: deco_options[option] for option in deco.options} - deco.flow_init( - flow, graph, environment, flow_datastore, metadata, logger, echo, opts - ) + if type(deco) == list: + for rd in deco: + opts = {option: deco_options[option] for option in rd.options} + rd.flow_init( + flow, + graph, + environment, + flow_datastore, + metadata, + logger, + echo, + opts, + ) + else: + opts = {option: deco_options[option] for option in deco.options} + deco.flow_init( + flow, graph, environment, flow_datastore, metadata, logger, echo, opts + ) def _init_step_decorators(flow, graph, environment, flow_datastore, logger): diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index b40b855b26a..ee1e74d9ba8 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -242,6 +242,15 @@ def from_conf(name, default=None): ) # +## +# Airflow Configuration +## +# This configuration sets `startup_timeout_seconds` in airflow's KubernetesPodOperator. +AIRFLOW_KUBERNETES_STARTUP_TIMEOUT = from_conf( + "METAFLOW_AIRFLOW_KUBERNETES_STARTUP_TIMEOUT_SECONDS", 60 * 60 +) + + ### # Conda configuration ### diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index b374e1a35b8..bb7aa4dbcc7 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -88,6 +88,7 @@ def get_plugin_cli(): from .aws.batch import batch_cli from .kubernetes import kubernetes_cli from .aws.step_functions import step_functions_cli + from .airflow import airflow_cli from .argo import argo_workflows_cli from .cards import card_cli from . 
import tag_cli @@ -98,6 +99,7 @@ def get_plugin_cli(): card_cli.cli, kubernetes_cli.cli, step_functions_cli.cli, + airflow_cli.cli, argo_workflows_cli.cli, tag_cli.cli, ] @@ -121,6 +123,7 @@ def get_plugin_cli(): from .conda.conda_step_decorator import CondaStepDecorator from .cards.card_decorator import CardDecorator from .frameworks.pytorch import PytorchParallelDecorator +from .airflow.airflow_decorator import AirflowInternalDecorator STEP_DECORATORS = [ @@ -137,6 +140,7 @@ def get_plugin_cli(): ParallelDecorator, PytorchParallelDecorator, InternalTestUnboundedForeachDecorator, + AirflowInternalDecorator, ArgoWorkflowsInternalDecorator, ] _merge_lists(STEP_DECORATORS, _ext_plugins["STEP_DECORATORS"], "name") @@ -161,7 +165,14 @@ def get_plugin_cli(): from .aws.step_functions.schedule_decorator import ScheduleDecorator from .project_decorator import ProjectDecorator -FLOW_DECORATORS = [CondaFlowDecorator, ScheduleDecorator, ProjectDecorator] + +from .airflow.sensors import SUPPORTED_SENSORS + +FLOW_DECORATORS = [ + CondaFlowDecorator, + ScheduleDecorator, + ProjectDecorator, +] + SUPPORTED_SENSORS _merge_lists(FLOW_DECORATORS, _ext_plugins["FLOW_DECORATORS"], "name") # Cards diff --git a/metaflow/plugins/airflow/__init__.py b/metaflow/plugins/airflow/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py new file mode 100644 index 00000000000..99043863dd1 --- /dev/null +++ b/metaflow/plugins/airflow/airflow.py @@ -0,0 +1,705 @@ +from io import BytesIO +import json +import os +import random +import string +import sys +from datetime import datetime, timedelta + +import metaflow.util as util +from metaflow.decorators import flow_decorators +from metaflow.exception import MetaflowException +from metaflow.metaflow_config import ( + BATCH_METADATA_SERVICE_HEADERS, + BATCH_METADATA_SERVICE_URL, + DATASTORE_CARD_S3ROOT, + DATASTORE_SYSROOT_S3, + DATATOOLS_S3ROOT, + KUBERNETES_SERVICE_ACCOUNT, + KUBERNETES_SECRETS, + AIRFLOW_KUBERNETES_STARTUP_TIMEOUT, +) +from metaflow.parameters import deploy_time_eval +from metaflow.plugins.kubernetes.kubernetes import Kubernetes + +# TODO: Move chevron to _vendor +from metaflow.plugins.cards.card_modules import chevron +from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task +from metaflow.util import dict_to_cli_options, get_username, compress_list +from metaflow.parameters import JSONTypeClass + +from . 
import airflow_utils +from .exception import AirflowException +from .sensors import SUPPORTED_SENSORS +from .airflow_utils import ( + TASK_ID_XCOM_KEY, + AirflowTask, + Workflow, + AIRFLOW_MACROS, +) +from metaflow import current + +AIRFLOW_DEPLOY_TEMPLATE_FILE = os.path.join(os.path.dirname(__file__), "dag.py") + + +class Airflow(object): + + TOKEN_STORAGE_ROOT = "mf.airflow" + + def __init__( + self, + name, + graph, + flow, + code_package_sha, + code_package_url, + metadata, + flow_datastore, + environment, + event_logger, + monitor, + production_token, + tags=None, + namespace=None, + username=None, + max_workers=None, + worker_pool=None, + description=None, + file_path=None, + workflow_timeout=None, + is_paused_upon_creation=True, + ): + self.name = name + self.graph = graph + self.flow = flow + self.code_package_sha = code_package_sha + self.code_package_url = code_package_url + self.metadata = metadata + self.flow_datastore = flow_datastore + self.environment = environment + self.event_logger = event_logger + self.monitor = monitor + self.tags = tags + self.namespace = namespace # this is the username space + self.username = username + self.max_workers = max_workers + self.description = description + self._depends_on_upstream_sensors = False + self._file_path = file_path + _, self.graph_structure = self.graph.output_steps() + self.worker_pool = worker_pool + self.is_paused_upon_creation = is_paused_upon_creation + self.workflow_timeout = workflow_timeout + self.schedule = self._get_schedule() + self.parameters = self._process_parameters() + self.production_token = production_token + + @classmethod + def get_existing_deployment(cls, name, flow_datastore): + _backend = flow_datastore._storage_impl + token_paths = _backend.list_content([cls.get_token_path(name)]) + if len(token_paths) == 0: + return None + + with _backend.load_bytes([token_paths[0]]) as get_results: + for _, path, _ in get_results: + if path is not None: + with open(path, "r") as f: + data = json.loads(f.read()) + return (data["owner"], data["production_token"]) + + @classmethod + def get_token_path(cls, name): + return os.path.join(cls.TOKEN_STORAGE_ROOT, name) + + @classmethod + def save_deployment_token(cls, owner, token, flow_datastore): + _backend = flow_datastore._storage_impl + _backend.save_bytes( + [ + ( + cls.get_token_path(token), + BytesIO( + bytes( + json.dumps({"production_token": token, "owner": owner}), + "utf-8", + ) + ), + ) + ], + overwrite=False, + ) + + def _get_schedule(self): + # Using the cron presets provided here : + # https://airflow.apache.org/docs/apache-airflow/stable/dag-run.html?highlight=schedule%20interval#cron-presets + schedule = self.flow._flow_decorators.get("schedule") + if not schedule: + return None + if schedule.attributes["cron"]: + return schedule.attributes["cron"] + elif schedule.attributes["weekly"]: + return "@weekly" + elif schedule.attributes["hourly"]: + return "@hourly" + elif schedule.attributes["daily"]: + return "@daily" + return None + + def _get_retries(self, node): + max_user_code_retries = 0 + max_error_retries = 0 + foreach_default_retry = 1 + # Different decorators may have different retrying strategies, so take + # the max of them. + for deco in node.decorators: + user_code_retries, error_retries = deco.step_task_retry_count() + max_user_code_retries = max(max_user_code_retries, user_code_retries) + max_error_retries = max(max_error_retries, error_retries) + parent_is_foreach = any( # The immediate parent is a foreach node. 
+ self.graph[n].type == "foreach" for n in node.in_funcs + ) + + if parent_is_foreach: + max_user_code_retries + foreach_default_retry + return max_user_code_retries, max_user_code_retries + max_error_retries + + def _get_retry_delay(self, node): + retry_decos = [deco for deco in node.decorators if deco.name == "retry"] + if len(retry_decos) > 0: + retry_mins = retry_decos[0].attributes["minutes_between_retries"] + return timedelta(minutes=int(retry_mins)) + return None + + def _process_parameters(self): + seen = set() + airflow_params = [] + type_transform_dict = { + int.__name__: "integer", + str.__name__: "string", + bool.__name__: "string", + float.__name__: "number", + JSONTypeClass.name: "string", + } + type_parser = {bool.__name__: lambda v: str(v)} + + for var, param in self.flow._get_parameters(): + # Throw an exception if the parameter is specified twice. + norm = param.name.lower() + if norm in seen: + raise MetaflowException( + "Parameter *%s* is specified twice. " + "Note that parameter names are " + "case-insensitive." % param.name + ) + seen.add(norm) + + # Airflow requires defaults set for parameters. + if "default" not in param.kwargs: + raise MetaflowException( + "Parameter *%s* does not have a " + "default value. " + "A default value is required for parameters when deploying flows on Airflow." + ) + value = deploy_time_eval(param.kwargs.get("default")) + # Setting airflow related param args. + param_type = param.kwargs.get("type") + airflow_param = dict( + name=param.name, + ) + if value is not None: + airflow_param["default"] = value + if param.kwargs.get("help"): + airflow_param["description"] = param.kwargs.get("help") + if param_type is not None: + # Todo (fast-follow): Check if we can support more `click.Param` types + param_type_name = getattr(param_type, "__name__", None) + if not param_type_name and isinstance(param_type, JSONTypeClass): + # `JSONTypeClass` has no __name__ attribute so we need to explicitly check if + # `param_type` is an instance of `JSONTypeClass`` + param_type_name = param_type.name + + if param_type_name in type_transform_dict: + airflow_param["type"] = type_transform_dict[param_type_name] + if param_type_name in type_parser and value is not None: + airflow_param["default"] = type_parser[param_type_name](value) + + airflow_params.append(airflow_param) + + return airflow_params + + def _compress_input_path( + self, + steps, + ): + """ + This function is meant to compress the input paths and it specifically doesn't use + `metaflow.util.compress_list` under the hood. The reason is because the `AIRFLOW_MACROS.RUN_ID` is a complicated macro string + that doesn't behave nicely with `metaflow.util.decompress_list` since the `decompress_util` + function expects a string which doesn't contain any delimiter characters and the run-id string does. + Hence we have a custom compression string created via `_compress_input_path` function instead of `compress_list`. + """ + return "%s:" % (AIRFLOW_MACROS.RUN_ID) + ",".join( + self._make_input_path(step, only_task_id=True) for step in steps + ) + + def _make_foreach_input_path(self, step_name): + return ( + "%s/%s/:{{ task_instance.xcom_pull(task_ids='%s',key='%s') | join_list }}" + % ( + AIRFLOW_MACROS.RUN_ID, + step_name, + step_name, + TASK_ID_XCOM_KEY, + ) + ) + + def _make_input_path(self, step_name, only_task_id=False): + """ + This is set using the `airflow_internal` decorator to help pass state. + This will pull the `TASK_ID_XCOM_KEY` xcom which holds task-ids. 
The key is set via the `MetaflowKubernetesOperator`. + """ + task_id_string = "/%s/{{ task_instance.xcom_pull(task_ids='%s',key='%s') }}" % ( + step_name, + step_name, + TASK_ID_XCOM_KEY, + ) + + if only_task_id: + return task_id_string + + return "%s%s" % (AIRFLOW_MACROS.RUN_ID, task_id_string) + + def _to_job(self, node): + """ + This function will transform the node's specification into Airflow compatible operator arguments. + Since this function is long, below is the summary of the two major duties it performs: + 1. Based on the type of the graph node (start/linear/foreach/join etc.) it will decide how to set the input paths + 2. Based on node's decorator specification convert the information into a job spec for the KubernetesPodOperator. + """ + # Add env vars from the optional @environment decorator. + env_deco = [deco for deco in node.decorators if deco.name == "environment"] + env = {} + if env_deco: + env = env_deco[0].attributes["vars"] + + # The below if/else block handles "input paths". + # Input Paths help manage dataflow across the graph. + if node.name == "start": + # POSSIBLE_FUTURE_IMPROVEMENT: + # We can extract metadata about the possible upstream sensor triggers. + # There is a previous commit (7bdf6) in the `airflow` branch that has `SensorMetaExtractor` class and + # associated MACRO we have built to handle this case if a metadata regarding the sensor is needed. + # Initialize parameters for the flow in the `start` step. + # `start` step has no upstream input dependencies aside from + # parameters. + + if len(self.parameters): + env["METAFLOW_PARAMETERS"] = AIRFLOW_MACROS.PARAMETERS + input_paths = None + else: + # If it is not the start node then we check if there are many paths + # converging into it or a single path. Based on that we set the INPUT_PATHS + if node.parallel_foreach: + raise AirflowException( + "Parallel steps are not supported yet with Airflow." + ) + is_foreach_join = ( + node.type == "join" + and self.graph[node.split_parents[-1]].type == "foreach" + ) + if is_foreach_join: + input_paths = self._make_foreach_input_path(node.in_funcs[0]) + + elif len(node.in_funcs) == 1: + # set input paths where this is only one parent node + # The parent-task-id is passed via the xcom; There is no other way to get that. + # One key thing about xcoms is that they are immutable and only accepted if the task + # doesn't fail. + # From airflow docs : + # "Note: If the first task run is not succeeded then on every retry task XComs will be cleared to make the task run idempotent." + input_paths = self._make_input_path(node.in_funcs[0]) + else: + # this is a split scenario where there can be more than one input paths. + input_paths = self._compress_input_path(node.in_funcs) + + # env["METAFLOW_INPUT_PATHS"] = input_paths + + env["METAFLOW_CODE_URL"] = self.code_package_url + env["METAFLOW_FLOW_NAME"] = self.flow.name + env["METAFLOW_STEP_NAME"] = node.name + env["METAFLOW_OWNER"] = self.username + + metadata_env = self.metadata.get_runtime_environment("airflow") + env.update(metadata_env) + + metaflow_version = self.environment.get_environment_info() + metaflow_version["flow_name"] = self.graph.name + metaflow_version["production_token"] = self.production_token + env["METAFLOW_VERSION"] = json.dumps(metaflow_version) + + # Extract the k8s decorators for constructing the arguments of the K8s Pod Operator on Airflow. 
+ k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0] + user_code_retries, _ = self._get_retries(node) + retry_delay = self._get_retry_delay(node) + # This sets timeouts for @timeout decorators. + # The timeout is set as "execution_timeout" for an airflow task. + runtime_limit = get_run_time_limit_for_task(node.decorators) + + k8s = Kubernetes(self.flow_datastore, self.metadata, self.environment) + user = util.get_username() + + labels = { + "app": "metaflow", + "app.kubernetes.io/name": "metaflow-task", + "app.kubernetes.io/part-of": "metaflow", + "app.kubernetes.io/created-by": user, + # Question to (savin) : Should we have username set over here for created by since it is the airflow installation that is creating the jobs. + # Technically the "user" is the stakeholder but should these labels be present. + } + additional_mf_variables = { + "METAFLOW_CODE_SHA": self.code_package_sha, + "METAFLOW_CODE_URL": self.code_package_url, + "METAFLOW_CODE_DS": self.flow_datastore.TYPE, + "METAFLOW_USER": user, + "METAFLOW_SERVICE_URL": BATCH_METADATA_SERVICE_URL, + "METAFLOW_SERVICE_HEADERS": json.dumps(BATCH_METADATA_SERVICE_HEADERS), + "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3, + "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT, + "METAFLOW_DEFAULT_DATASTORE": "s3", + "METAFLOW_DEFAULT_METADATA": "service", + "METAFLOW_KUBERNETES_WORKLOAD": str( + 1 + ), # This is used by kubernetes decorator. + "METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes", + "METAFLOW_CARD_S3ROOT": DATASTORE_CARD_S3ROOT, + "METAFLOW_RUN_ID": AIRFLOW_MACROS.RUN_ID, + "METAFLOW_AIRFLOW_TASK_ID": AIRFLOW_MACROS.TASK_ID, + "METAFLOW_AIRFLOW_DAG_RUN_ID": AIRFLOW_MACROS.AIRFLOW_RUN_ID, + "METAFLOW_AIRFLOW_JOB_ID": AIRFLOW_MACROS.AIRFLOW_JOB_ID, + "METAFLOW_PRODUCTION_TOKEN": self.production_token, + "METAFLOW_ATTEMPT_NUMBER": AIRFLOW_MACROS.ATTEMPT, + } + env.update(additional_mf_variables) + service_account = ( + KUBERNETES_SERVICE_ACCOUNT + if k8s_deco.attributes["service_account"] is None + else k8s_deco.attributes["service_account"] + ) + k8s_namespace = ( + k8s_deco.attributes["namespace"] + if k8s_deco.attributes["namespace"] is not None + else "default" + ) + + resources = dict( + requests={ + "cpu": k8s_deco.attributes["cpu"], + "memory": "%sM" % str(k8s_deco.attributes["memory"]), + "ephemeral-storage": str(k8s_deco.attributes["disk"]), + } + ) + if k8s_deco.attributes["gpu"] is not None: + resources.update( + dict( + limits={ + "%s.com/gpu".lower() + % k8s_deco.attributes["gpu_vendor"]: str( + k8s_deco.attributes["gpu"] + ) + } + ) + ) + + annotations = { + "metaflow/production_token": self.production_token, + "metaflow/owner": self.username, + "metaflow/user": self.username, + "metaflow/flow_name": self.flow.name, + } + if current.get("project_name"): + annotations.update( + { + "metaflow/project_name": current.project_name, + "metaflow/branch_name": current.branch_name, + "metaflow/project_flow_name": current.project_flow_name, + } + ) + + k8s_operator_args = dict( + # like argo workflows we use step_name as name of container + name=node.name, + namespace=k8s_namespace, + service_account_name=service_account, + node_selector=k8s_deco.attributes["node_selector"], + cmds=k8s._command( + self.flow.name, + AIRFLOW_MACROS.RUN_ID, + node.name, + AIRFLOW_MACROS.TASK_ID, + AIRFLOW_MACROS.ATTEMPT, + code_package_url=self.code_package_url, + step_cmds=self._step_cli( + node, input_paths, self.code_package_url, user_code_retries + ), + ), + annotations=annotations, + 
image=k8s_deco.attributes["image"], + resources=resources, + execution_timeout=dict(seconds=runtime_limit), + retries=user_code_retries, + env_vars=[dict(name=k, value=v) for k, v in env.items()], + labels=labels, + task_id=node.name, + startup_timeout_seconds=AIRFLOW_KUBERNETES_STARTUP_TIMEOUT, + in_cluster=True, + get_logs=True, + do_xcom_push=True, + log_events_on_failure=True, + is_delete_operator_pod=True, + retry_exponential_backoff=False, # todo : should this be a arg we allow on CLI. not right now - there is an open ticket for this - maybe at some point we will. + reattach_on_restart=False, + secrets=[], + ) + if k8s_deco.attributes["secrets"]: + if isinstance(k8s_deco.attributes["secrets"], str): + k8s_operator_args["secrets"] = k8s_deco.attributes["secrets"].split(",") + elif isinstance(k8s_deco.attributes["secrets"], list): + k8s_operator_args["secrets"] = k8s_deco.attributes["secrets"] + if len(KUBERNETES_SECRETS) > 0: + k8s_operator_args["secrets"] += KUBERNETES_SECRETS.split(",") + + if retry_delay: + k8s_operator_args["retry_delay"] = dict(seconds=retry_delay.total_seconds()) + + return k8s_operator_args + + def _step_cli(self, node, paths, code_package_url, user_code_retries): + cmds = [] + + script_name = os.path.basename(sys.argv[0]) + executable = self.environment.executable(node.name) + + entrypoint = [executable, script_name] + + top_opts_dict = { + "with": [ + decorator.make_decorator_spec() + for decorator in node.decorators + if not decorator.statically_defined + ] + } + # FlowDecorators can define their own top-level options. They are + # responsible for adding their own top-level options and values through + # the get_top_level_options() hook. See similar logic in runtime.py. + for deco in flow_decorators(): + top_opts_dict.update(deco.get_top_level_options()) + + top_opts = list(dict_to_cli_options(top_opts_dict)) + + top_level = top_opts + [ + "--quiet", + "--metadata=%s" % self.metadata.TYPE, + "--environment=%s" % self.environment.TYPE, + "--datastore=%s" % self.flow_datastore.TYPE, + "--datastore-root=%s" % self.flow_datastore.datastore_root, + "--event-logger=%s" % self.event_logger.TYPE, + "--monitor=%s" % self.monitor.TYPE, + "--no-pylint", + "--with=airflow_internal", + ] + + if node.name == "start": + # We need a separate unique ID for the special _parameters task + task_id_params = "%s-params" % AIRFLOW_MACROS.TASK_ID + # Export user-defined parameters into runtime environment + param_file = "".join( + random.choice(string.ascii_lowercase) for _ in range(10) + ) + # Setup Parameters as environment variables which are stored in a dictionary. + export_params = ( + "python -m " + "metaflow.plugins.airflow.plumbing.set_parameters %s " + "&& . `pwd`/%s" % (param_file, param_file) + ) + # Setting parameters over here. + params = ( + entrypoint + + top_level + + [ + "init", + "--run-id %s" % AIRFLOW_MACROS.RUN_ID, + "--task-id %s" % task_id_params, + ] + ) + + # Assign tags to run objects. + if self.tags: + params.extend("--tag %s" % tag for tag in self.tags) + + # If the start step gets retried, we must be careful not to + # regenerate multiple parameters tasks. Hence we check first if + # _parameters exists already. + exists = entrypoint + [ + # Dump the parameters task + "dump", + "--max-value-size=0", + "%s/_parameters/%s" % (AIRFLOW_MACROS.RUN_ID, task_id_params), + ] + cmd = "if ! 
%s >/dev/null 2>/dev/null; then %s && %s; fi" % ( + " ".join(exists), + export_params, + " ".join(params), + ) + cmds.append(cmd) + # set input paths for parameters + paths = "%s/_parameters/%s" % (AIRFLOW_MACROS.RUN_ID, task_id_params) + + step = [ + "step", + node.name, + "--run-id %s" % AIRFLOW_MACROS.RUN_ID, + "--task-id %s" % AIRFLOW_MACROS.TASK_ID, + "--retry-count %s" % AIRFLOW_MACROS.ATTEMPT, + "--max-user-code-retries %d" % user_code_retries, + "--input-paths %s" % paths, + ] + if self.tags: + step.extend("--tag %s" % tag for tag in self.tags) + if self.namespace is not None: + step.append("--namespace=%s" % self.namespace) + + parent_is_foreach = any( # The immediate parent is a foreach node. + self.graph[n].type == "foreach" for n in node.in_funcs + ) + if parent_is_foreach: + step.append("--split-index %s" % AIRFLOW_MACROS.FOREACH_SPLIT_INDEX) + + cmds.append(" ".join(entrypoint + top_level + step)) + return cmds + + def _collect_flow_sensors(self): + decos_lists = [ + self.flow._flow_decorators.get(s.name) + for s in SUPPORTED_SENSORS + if self.flow._flow_decorators.get(s.name) is not None + ] + af_tasks = [deco.create_task() for decos in decos_lists for deco in decos] + if len(af_tasks) > 0: + self._depends_on_upstream_sensors = True + return af_tasks + + def _contains_foreach(self): + for node in self.graph: + if node.type == "foreach": + return True + return False + + def compile(self): + # Visit every node of the flow and recursively build the state machine. + def _visit(node, workflow, exit_node=None): + if node.parallel_foreach: + raise AirflowException( + "Deploying flows with @parallel decorator(s) " + "to Airflow is not supported currently." + ) + parent_is_foreach = any( # Any immediate parent is a foreach node. + self.graph[n].type == "foreach" for n in node.in_funcs + ) + state = AirflowTask( + node.name, is_mapper_node=parent_is_foreach + ).set_operator_args(**self._to_job(node)) + if node.type == "end": + workflow.add_state(state) + + # Continue linear assignment within the (sub)workflow if the node + # doesn't branch or fork. + elif node.type in ("start", "linear", "join", "foreach"): + workflow.add_state(state) + _visit( + self.graph[node.out_funcs[0]], + workflow, + ) + + elif node.type == "split": + workflow.add_state(state) + for func in node.out_funcs: + _visit( + self.graph[func], + workflow, + ) + else: + raise AirflowException( + "Node type *%s* for step *%s* " + "is not currently supported by " + "Airflow." % (node.type, node.name) + ) + + return workflow + + # set max active tasks here , For more info check here : + # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/dag/index.html#airflow.models.dag.DAG + airflow_dag_args = ( + {} if self.max_workers is None else dict(max_active_tasks=self.max_workers) + ) + airflow_dag_args["is_paused_upon_creation"] = self.is_paused_upon_creation + + # workflow timeout should only be enforced if a dag is scheduled. 
+ if self.workflow_timeout is not None and self.schedule is not None: + airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout) + + appending_sensors = self._collect_flow_sensors() + workflow = Workflow( + dag_id=self.name, + default_args=self._create_defaults(), + description=self.description, + schedule_interval=self.schedule, + # `start_date` is a mandatory argument even though the documentation lists it as optional value + # Based on the code, Airflow will throw a `AirflowException` when `start_date` is not provided + # to a DAG : https://github.com/apache/airflow/blob/0527a0b6ce506434a23bc2a6f5ddb11f492fc614/airflow/models/dag.py#L2170 + start_date=datetime.now(), + tags=self.tags, + file_path=self._file_path, + graph_structure=self.graph_structure, + metadata=dict(contains_foreach=self._contains_foreach()), + **airflow_dag_args + ) + workflow = _visit(self.graph["start"], workflow) + + workflow.set_parameters(self.parameters) + if len(appending_sensors) > 0: + for s in appending_sensors: + workflow.add_state(s) + workflow.graph_structure.insert(0, [[s.name] for s in appending_sensors]) + return self._to_airflow_dag_file(workflow.to_dict()) + + def _to_airflow_dag_file(self, json_dag): + util_file = None + with open(airflow_utils.__file__) as f: + util_file = f.read() + with open(AIRFLOW_DEPLOY_TEMPLATE_FILE) as f: + return chevron.render( + f.read(), + dict( + # Converting the configuration to base64 so that there can be no indentation related issues that can be caused because of + # malformed strings / json. + config=json_dag, + utils=util_file, + deployed_on=str(datetime.now()), + ), + ) + + def _create_defaults(self): + defu_ = { + "owner": get_username(), + # If set on a task, doesn’t run the task in the current DAG run if the previous run of the task has failed. + "depends_on_past": False, + # TODO: Enable emails + "execution_timeout": timedelta(days=5), + "retry_delay": timedelta(seconds=200), + # check https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/baseoperator/index.html?highlight=retry_delay#airflow.models.baseoperator.BaseOperatorMeta + } + if self.worker_pool is not None: + defu_["pool"] = self.worker_pool + + return defu_ diff --git a/metaflow/plugins/airflow/airflow_cli.py b/metaflow/plugins/airflow/airflow_cli.py new file mode 100644 index 00000000000..39a904c2ab3 --- /dev/null +++ b/metaflow/plugins/airflow/airflow_cli.py @@ -0,0 +1,404 @@ +import os +import re +import sys +import base64 +from metaflow import current, decorators +from metaflow._vendor import click +from metaflow.exception import MetaflowException, MetaflowInternalError +from metaflow.package import MetaflowPackage +from hashlib import sha1 +from metaflow.plugins import KubernetesDecorator +from metaflow.util import get_username, to_bytes, to_unicode + +from .airflow import Airflow +from .exception import AirflowException, NotSupportedException + +from metaflow.plugins.aws.step_functions.production_token import ( + load_token, + new_token, + store_token, +) + + +class IncorrectProductionToken(MetaflowException): + headline = "Incorrect production token" + + +VALID_NAME = re.compile("[^a-zA-Z0-9_\-\.]") + + +def resolve_token( + name, token_prefix, obj, authorize, given_token, generate_new_token, is_project +): + # 1) retrieve the previous deployment, if one exists + + workflow = Airflow.get_existing_deployment(name, obj.flow_datastore) + if workflow is None: + obj.echo( + "It seems this is the first time you are deploying *%s* to " + "Airflow." 
% name + ) + prev_token = None + else: + prev_user, prev_token = workflow + + # 2) authorize this deployment + if prev_token is not None: + if authorize is None: + authorize = load_token(token_prefix) + elif authorize.startswith("production:"): + authorize = authorize[11:] + + # we allow the user who deployed the previous version to re-deploy, + # even if they don't have the token + if prev_user != get_username() and authorize != prev_token: + obj.echo( + "There is an existing version of *%s* on Airflow which was " + "deployed by the user *%s*." % (name, prev_user) + ) + obj.echo( + "To deploy a new version of this flow, you need to use the same " + "production token that they used. " + ) + obj.echo( + "Please reach out to them to get the token. Once you have it, call " + "this command:" + ) + obj.echo(" airflow create --authorize MY_TOKEN", fg="green") + obj.echo( + 'See "Organizing Results" at docs.metaflow.org for more information ' + "about production tokens." + ) + raise IncorrectProductionToken( + "Try again with the correct production token." + ) + + # 3) do we need a new token or should we use the existing token? + if given_token: + if is_project: + # we rely on a known prefix for @project tokens, so we can't + # allow the user to specify a custom token with an arbitrary prefix + raise MetaflowException( + "--new-token is not supported for @projects. Use --generate-new-token " + "to create a new token." + ) + if given_token.startswith("production:"): + given_token = given_token[11:] + token = given_token + obj.echo("") + obj.echo("Using the given token, *%s*." % token) + elif prev_token is None or generate_new_token: + token = new_token(token_prefix, prev_token) + if token is None: + if prev_token is None: + raise MetaflowInternalError( + "We could not generate a new token. This is unexpected. " + ) + else: + raise MetaflowException( + "--generate-new-token option is not supported after using " + "--new-token. Use --new-token to make a new namespace." + ) + obj.echo("") + obj.echo("A new production token generated.") + Airflow.save_deployment_token(get_username(), token, obj.flow_datastore) + else: + token = prev_token + + obj.echo("") + obj.echo("The namespace of this production flow is") + obj.echo(" production:%s" % token, fg="green") + obj.echo( + "To analyze results of this production flow add this line in your notebooks:" + ) + obj.echo(' namespace("production:%s")' % token, fg="green") + obj.echo( + "If you want to authorize other people to deploy new versions of this flow to " + "Airflow, they need to call" + ) + obj.echo(" airflow create --authorize %s" % token, fg="green") + obj.echo("when deploying this flow to Airflow for the first time.") + obj.echo( + 'See "Organizing Results" at https://docs.metaflow.org/ for more ' + "information about production tokens." + ) + obj.echo("") + store_token(token_prefix, token) + + return token + + +@click.group() +def cli(): + pass + + +@cli.group(help="Commands related to Airflow.") +@click.option( + "--name", + default=None, + type=str, + help="Airflow DAG name. The flow name is used instead if this option is not " + "specified", +) +@click.pass_obj +def airflow(obj, name=None): + obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) + obj.dag_name, obj.token_prefix, obj.is_project = resolve_dag_name(name) + + +@airflow.command(help="Compile a new version of this flow to Airflow DAG.") +@click.argument("file", required=True) +@click.option( + "--authorize", + default=None, + help="Authorize using this production token. 
You need this " + "when you are re-deploying an existing flow for the first " + "time. The token is cached in METAFLOW_HOME, so you only " + "need to specify this once.", +) +@click.option( + "--generate-new-token", + is_flag=True, + help="Generate a new production token for this flow. " + "This will move the production flow to a new namespace.", +) +@click.option( + "--new-token", + "given_token", + default=None, + help="Use the given production token for this flow. " + "This will move the production flow to the given namespace.", +) +@click.option( + "--tag", + "tags", + multiple=True, + default=None, + help="Annotate all objects produced by Airflow DAG executions " + "with the given tag. You can specify this option multiple " + "times to attach multiple tags.", +) +@click.option( + "--is-paused-upon-creation", + default=False, + is_flag=True, + help="Generated Airflow DAG is paused/unpaused upon creation.", +) +@click.option( + "--namespace", + "user_namespace", + default=None, + # TODO (savin): Identify the default namespace? + help="Change the namespace from the default to the given tag. " + "See run --help for more information.", +) +@click.option( + "--max-workers", + default=100, + show_default=True, + help="Maximum number of parallel processes.", +) +@click.option( + "--workflow-timeout", + default=None, + type=int, + help="Workflow timeout in seconds. Enforced only for scheduled DAGs.", +) +@click.option( + "--worker-pool", + default=None, + show_default=True, + help="Worker pool for Airflow DAG execution.", +) +@click.pass_obj +def create( + obj, + file, + authorize=None, + generate_new_token=False, + given_token=None, + tags=None, + is_paused_upon_creation=False, + user_namespace=None, + max_workers=None, + workflow_timeout=None, + worker_pool=None, +): + if os.path.abspath(sys.argv[0]) == os.path.abspath(file): + raise MetaflowException( + "Airflow DAG file name cannot be the same as flow file name" + ) + + obj.echo("Compiling *%s* to Airflow DAG..." % obj.dag_name, bold=True) + + token = resolve_token( + obj.dag_name, + obj.token_prefix, + obj, + authorize, + given_token, + generate_new_token, + obj.is_project, + ) + + flow = make_flow( + obj, + obj.dag_name, + token, + tags, + is_paused_upon_creation, + user_namespace, + max_workers, + workflow_timeout, + worker_pool, + file, + ) + with open(file, "w") as f: + f.write(flow.compile()) + + obj.echo( + "DAG *{dag_name}* " + "for flow *{name}* compiled to " + "Airflow successfully.\n".format(dag_name=obj.dag_name, name=current.flow_name), + bold=True, + ) + + +def make_flow( + obj, + dag_name, + production_token, + tags, + is_paused_upon_creation, + namespace, + max_workers, + workflow_timeout, + worker_pool, + file, +): + # Validate if the workflow is correctly parsed. + _validate_workflow( + obj.flow, obj.graph, obj.flow_datastore, obj.metadata, workflow_timeout + ) + + # Attach @kubernetes. + decorators._attach_decorators(obj.flow, [KubernetesDecorator.name]) + + decorators._init_step_decorators( + obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + ) + + # Save the code package in the flow datastore so that both user code and + # metaflow package can be retrieved during workflow execution. 
+ obj.package = MetaflowPackage( + obj.flow, obj.environment, obj.echo, obj.package_suffixes + ) + package_url, package_sha = obj.flow_datastore.save_data( + [obj.package.blob], len_hint=1 + )[0] + + return Airflow( + dag_name, + obj.graph, + obj.flow, + package_sha, + package_url, + obj.metadata, + obj.flow_datastore, + obj.environment, + obj.event_logger, + obj.monitor, + production_token, + tags=tags, + namespace=namespace, + username=get_username(), + max_workers=max_workers, + worker_pool=worker_pool, + workflow_timeout=workflow_timeout, + description=obj.flow.__doc__, + file_path=file, + is_paused_upon_creation=is_paused_upon_creation, + ) + + +def _validate_foreach_constraints(graph): + def traverse_graph(node, state): + if node.type == "foreach" and node.is_inside_foreach: + raise NotSupportedException( + "Step *%s* is a foreach step called within a foreach step. This type of graph is currently not supported with Airflow." + % node.name + ) + + if node.type == "foreach": + state["foreach_stack"] = [node.name] + + if node.type in ("start", "linear", "join", "foreach"): + if node.type == "linear" and node.is_inside_foreach: + state["foreach_stack"].append(node.name) + + if len(state["foreach_stack"]) > 2: + raise NotSupportedException( + "The foreach step *%s* created by step *%s* needs to have an immidiate join step. " + "Step *%s* is invalid since it is a linear step with a foreach. " + "This type of graph is currently not supported with Airflow." + % ( + state["foreach_stack"][1], + state["foreach_stack"][0], + state["foreach_stack"][-1], + ) + ) + + traverse_graph(graph[node.out_funcs[0]], state) + + elif node.type == "split": + for func in node.out_funcs: + traverse_graph(graph[func], state) + + traverse_graph(graph["start"], {}) + + +def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout): + # check for other compute related decorators. + _validate_foreach_constraints(graph) + for node in graph: + if any([d.name == "batch" for d in node.decorators]): + raise NotSupportedException( + "Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported." + % node.name + ) + + if flow_datastore.TYPE != "s3": + raise AirflowException('Datastore of type "s3" required with `airflow create`') + + +def resolve_dag_name(name): + project = current.get("project_name") + is_project = False + + if project: + is_project = True + if name: + raise MetaflowException( + "--name is not supported for @projects. " "Use --branch instead." + ) + dag_name = current.project_flow_name + if dag_name and VALID_NAME.search(dag_name): + raise MetaflowException( + "Name '%s' contains invalid characters. Please construct a name using regex %s" + % (dag_name, VALID_NAME.pattern) + ) + project_branch = to_bytes(".".join((project, current.branch_name))) + token_prefix = ( + "mfprj-%s" + % to_unicode(base64.b32encode(sha1(project_branch).digest()))[:16] + ) + else: + if name and VALID_NAME.search(name): + raise MetaflowException( + "Name '%s' contains invalid characters. 
Please construct a name using regex %s" + % (name, VALID_NAME.pattern) + ) + dag_name = name if name else current.flow_name + token_prefix = dag_name + return dag_name, token_prefix.lower(), is_project diff --git a/metaflow/plugins/airflow/airflow_decorator.py b/metaflow/plugins/airflow/airflow_decorator.py new file mode 100644 index 00000000000..11bdecaaa8b --- /dev/null +++ b/metaflow/plugins/airflow/airflow_decorator.py @@ -0,0 +1,66 @@ +import json +import os +from metaflow.decorators import StepDecorator +from metaflow.metadata import MetaDatum + +from .airflow_utils import ( + TASK_ID_XCOM_KEY, + FOREACH_CARDINALITY_XCOM_KEY, +) + +K8S_XCOM_DIR_PATH = "/airflow/xcom" + + +def safe_mkdir(dir): + try: + os.makedirs(dir) + except FileExistsError: + pass + + +def push_xcom_values(xcom_dict): + safe_mkdir(K8S_XCOM_DIR_PATH) + with open(os.path.join(K8S_XCOM_DIR_PATH, "return.json"), "w") as f: + json.dump(xcom_dict, f) + + +class AirflowInternalDecorator(StepDecorator): + name = "airflow_internal" + + def task_pre_step( + self, + step_name, + task_datastore, + metadata, + run_id, + task_id, + flow, + graph, + retry_count, + max_user_code_retries, + ubf_context, + inputs, + ): + meta = {} + meta["airflow-dag-run-id"] = os.environ["METAFLOW_AIRFLOW_DAG_RUN_ID"] + meta["airflow-job-id"] = os.environ["METAFLOW_AIRFLOW_JOB_ID"] + entries = [ + MetaDatum( + field=k, value=v, type=k, tags=["attempt_id:{0}".format(retry_count)] + ) + for k, v in meta.items() + ] + + # Register book-keeping metadata for debugging. + metadata.register_metadata(run_id, step_name, task_id, entries) + + def task_finished( + self, step_name, flow, graph, is_task_ok, retry_count, max_user_code_retries + ): + # This will pass the xcom when the task finishes. + xcom_values = { + TASK_ID_XCOM_KEY: os.environ["METAFLOW_AIRFLOW_TASK_ID"], + } + if graph[step_name].type == "foreach": + xcom_values[FOREACH_CARDINALITY_XCOM_KEY] = flow._foreach_num_splits + push_xcom_values(xcom_values) diff --git a/metaflow/plugins/airflow/airflow_utils.py b/metaflow/plugins/airflow/airflow_utils.py new file mode 100644 index 00000000000..c3d94c7bcb1 --- /dev/null +++ b/metaflow/plugins/airflow/airflow_utils.py @@ -0,0 +1,646 @@ +import hashlib +import json +import sys +from collections import defaultdict +from datetime import datetime, timedelta + + +TASK_ID_XCOM_KEY = "metaflow_task_id" +FOREACH_CARDINALITY_XCOM_KEY = "metaflow_foreach_cardinality" +FOREACH_XCOM_KEY = "metaflow_foreach_indexes" +RUN_HASH_ID_LEN = 12 +TASK_ID_HASH_LEN = 8 +RUN_ID_PREFIX = "airflow" +AIRFLOW_FOREACH_SUPPORT_VERSION = "2.3.0" +AIRFLOW_MIN_SUPPORT_VERSION = "2.0.0" +KUBERNETES_PROVIDER_FOREACH_VERSION = "4.2.0" + + +class KubernetesProviderNotFound(Exception): + headline = "Kubernetes provider not found" + + +class ForeachIncompatibleException(Exception): + headline = "Airflow version is incompatible to support Metaflow foreach's." + + +class IncompatibleVersionException(Exception): + headline = "Metaflow is incompatible with current version of Airflow." + + def __init__(self, version_number) -> None: + msg = ( + "Airflow version %s is incompatible with Metaflow. Metaflow requires Airflow a minimum version %s" + % (version_number, AIRFLOW_MIN_SUPPORT_VERSION) + ) + super().__init__(msg) + + +class IncompatibleKubernetesProviderVersionException(Exception): + headline = ( + "Kubernetes Provider version is incompatible with Metaflow foreach's. 
" + "Install the provider via " + "`%s -m pip install apache-airflow-providers-cncf-kubernetes==%s`" + ) % (sys.executable, KUBERNETES_PROVIDER_FOREACH_VERSION) + + +class AirflowSensorNotFound(Exception): + headline = "Sensor package not found" + + +def create_absolute_version_number(version): + abs_version = None + # For all digits + if all(v.isdigit() for v in version.split(".")): + abs_version = sum( + [ + (10 ** (3 - idx)) * i + for idx, i in enumerate([int(v) for v in version.split(".")]) + ] + ) + # For first two digits + elif all(v.isdigit() for v in version.split(".")[:2]): + abs_version = sum( + [ + (10 ** (3 - idx)) * i + for idx, i in enumerate([int(v) for v in version.split(".")[:2]]) + ] + ) + return abs_version + + +def _validate_dyanmic_mapping_compatibility(): + from airflow.version import version + + af_ver = create_absolute_version_number(version) + if af_ver is None or af_ver < create_absolute_version_number( + AIRFLOW_FOREACH_SUPPORT_VERSION + ): + ForeachIncompatibleException( + "Please install airflow version %s to use Airflow's Dynamic task mapping functionality." + % AIRFLOW_FOREACH_SUPPORT_VERSION + ) + + +def get_kubernetes_provider_version(): + try: + from airflow.providers.cncf.kubernetes.get_provider_info import ( + get_provider_info, + ) + except ImportError as e: + raise KubernetesProviderNotFound( + "This DAG utilizes `KubernetesPodOperator`. " + "Install the Airflow Kubernetes provider using " + "`%s -m pip install apache-airflow-providers-cncf-kubernetes`" + % sys.executable + ) + return get_provider_info()["versions"][0] + + +def _validate_minimum_airflow_version(): + from airflow.version import version + + af_ver = create_absolute_version_number(version) + if af_ver is None or af_ver < create_absolute_version_number( + AIRFLOW_MIN_SUPPORT_VERSION + ): + raise IncompatibleVersionException(version) + + +def _check_foreach_compatible_kubernetes_provider(): + provider_version = get_kubernetes_provider_version() + ver = create_absolute_version_number(provider_version) + if ver is None or ver < create_absolute_version_number( + KUBERNETES_PROVIDER_FOREACH_VERSION + ): + raise IncompatibleKubernetesProviderVersionException() + + +class AIRFLOW_MACROS: + # run_id_creator is added via the `user_defined_filters` + RUN_ID = "%s-{{ [run_id, dag_run.dag_id] | run_id_creator }}" % RUN_ID_PREFIX + PARAMETERS = "{{ params | json_dump }}" + + # AIRFLOW_MACROS.TASK_ID will work for linear/branched workflows. + # ti.task_id is the stepname in metaflow code. + # AIRFLOW_MACROS.TASK_ID uses a jinja filter called `task_id_creator` which helps + # concatenate the string using a `/`. Since run-id will keep changing and stepname will be + # the same task id will change. Since airflow doesn't encourage dynamic rewriting of dags + # we can rename steps in a foreach with indexes (eg. `stepname-$index`) to create those steps. + # Hence : Foreachs will require some special form of plumbing. + # https://stackoverflow.com/questions/62962386/can-an-airflow-task-dynamically-generate-a-dag-at-runtime + TASK_ID = ( + "%s-{{ [run_id, ti.task_id, dag_run.dag_id, ti.map_index] | task_id_creator }}" + % RUN_ID_PREFIX + ) + + # Airflow run_ids are of the form : "manual__2022-03-15T01:26:41.186781+00:00" + # Such run-ids break the `metaflow.util.decompress_list`; this is why we hash the runid + # We do echo -n because emits line breaks and we dont want to consider that since it we want same hash value when retrieved in python. 
+ RUN_ID_SHELL = ( + "%s-$(echo -n {{ run_id }}-{{ dag_run.dag_id }} | md5sum | awk '{print $1}' | awk '{print substr ($0, 0, %s)}')" + % (RUN_ID_PREFIX, str(RUN_HASH_ID_LEN)) + ) + + ATTEMPT = "{{ task_instance.try_number - 1 }}" + + AIRFLOW_RUN_ID = "{{ run_id }}" + + AIRFLOW_JOB_ID = "{{ ti.job_id }}" + + FOREACH_SPLIT_INDEX = "{{ ti.map_index }}" + + +class SensorNames: + EXTERNAL_TASK_SENSOR = "ExternalTaskSensor" + S3_SENSOR = "S3KeySensor" + SQL_SENSOR = "SQLSensor" + + @classmethod + def get_supported_sensors(cls): + return list(cls.__dict__.values()) + + +def run_id_creator(val): + # join `[dag-id,run-id]` of airflow dag. + return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[ + :RUN_HASH_ID_LEN + ] + + +def task_id_creator(val): + # join `[dag-id,run-id]` of airflow dag. + return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[ + :TASK_ID_HASH_LEN + ] + + +def id_creator(val, hash_len): + # join `[dag-id,run-id]` of airflow dag. + return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[ + :hash_len + ] + + +def json_dump(val): + return json.dumps(val) + + +class AirflowDAGArgs(object): + + # `_arg_types` is a dictionary which represents the types of the arguments of an Airflow `DAG`. + # `_arg_types` is used when parsing types back from the configuration json. + # It doesn't cover all the arguments but covers many of the important one which can come from the cli. + _arg_types = { + "dag_id": str, + "description": str, + "schedule_interval": str, + "start_date": datetime, + "catchup": bool, + "tags": list, + "dagrun_timeout": timedelta, + "default_args": { + "owner": str, + "depends_on_past": bool, + "email": list, + "email_on_failure": bool, + "email_on_retry": bool, + "retries": int, + "retry_delay": timedelta, + "queue": str, # which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues. 
+ "pool": str, # the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks + "priority_weight": int, + "wait_for_downstream": bool, + "sla": timedelta, + "execution_timeout": timedelta, + "trigger_rule": str, + }, + } + + # Reference for user_defined_filters : https://stackoverflow.com/a/70175317 + filters = dict( + task_id_creator=lambda v: task_id_creator(v), + json_dump=lambda val: json_dump(val), + run_id_creator=lambda val: run_id_creator(val), + join_list=lambda x: ",".join(list(x)), + ) + + def __init__(self, **kwargs): + self._args = kwargs + + @property + def arguments(self): + return dict(**self._args, user_defined_filters=self.filters) + + def serialize(self): + def parse_args(dd): + data_dict = {} + for k, v in dd.items(): + if isinstance(v, dict): + data_dict[k] = parse_args(v) + elif isinstance(v, datetime): + data_dict[k] = v.isoformat() + elif isinstance(v, timedelta): + data_dict[k] = dict(seconds=v.total_seconds()) + else: + data_dict[k] = v + return data_dict + + return parse_args(self._args) + + @classmethod + def deserialize(cls, data_dict): + def parse_args(dd, type_check_dict): + kwrgs = {} + for k, v in dd.items(): + if k not in type_check_dict: + kwrgs[k] = v + elif isinstance(v, dict) and isinstance(type_check_dict[k], dict): + kwrgs[k] = parse_args(v, type_check_dict[k]) + elif type_check_dict[k] == datetime: + kwrgs[k] = datetime.fromisoformat(v) + elif type_check_dict[k] == timedelta: + kwrgs[k] = timedelta(**v) + else: + kwrgs[k] = v + return kwrgs + + return cls(**parse_args(data_dict, cls._arg_types)) + + +def _kubernetes_pod_operator_args(operator_args): + from kubernetes import client + + from airflow.kubernetes.secret import Secret + + # Set dynamic env variables like run-id, task-id etc from here. + secrets = [ + Secret("env", secret, secret) for secret in operator_args.get("secrets", []) + ] + args = operator_args + args.update( + { + "secrets": secrets, + # Question for (savin): + # Default timeout in airflow is 120. I can remove `startup_timeout_seconds` for now. how should we expose it to the user? + } + ) + # We need to explicity add the `client.V1EnvVar` over here because + # `pod_runtime_info_envs` doesn't accept arguments in dictionary form and strictly + # Requires objects of type `client.V1EnvVar` + additional_env_vars = [ + client.V1EnvVar( + name=k, + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector(field_path=str(v)) + ), + ) + for k, v in { + "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace", + "METAFLOW_KUBERNETES_POD_NAME": "metadata.name", + "METAFLOW_KUBERNETES_POD_ID": "metadata.uid", + "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName", + }.items() + ] + args["pod_runtime_info_envs"] = additional_env_vars + + resources = args.get("resources") + # KubernetesPodOperator version 4.2.0 renamed `resources` to + # `container_resources` (https://github.com/apache/airflow/pull/24673) / (https://github.com/apache/airflow/commit/45f4290712f5f779e57034f81dbaab5d77d5de85) + # This was done because `KubernetesPodOperator` didn't play nice with dynamic task mapping + # and they had to deprecate the `resources` argument. Hence the below codepath checks for the version of `KubernetesPodOperator` + # and then sets the argument. If the version < 4.2.0 then we set the argument as `resources`. + # If it is > 4.2.0 then we set the argument as `container_resources` + # The `resources` argument of KuberentesPodOperator is going to be deprecated soon in the future. 
+ # So we will only use it for `KuberentesPodOperator` version < 4.2.0 + # The `resources` argument will also not work for foreach's. + provider_version = get_kubernetes_provider_version() + k8s_op_ver = create_absolute_version_number(provider_version) + if k8s_op_ver is None or k8s_op_ver < create_absolute_version_number( + KUBERNETES_PROVIDER_FOREACH_VERSION + ): + # Since the provider version is less than `4.2.0` so we need to use the `resources` argument + # We need to explicitly parse `resources`/`container_resources` to k8s.V1ResourceRequirements otherwise airflow tries + # to parse dictionaries to `airflow.providers.cncf.kubernetes.backcompat.pod.Resources` object via + # `airflow.providers.cncf.kubernetes.backcompat.backward_compat_converts.convert_resources` function. + # This fails many times since the dictionary structure it expects is not the same as `client.V1ResourceRequirements`. + args["resources"] = client.V1ResourceRequirements( + requests=resources["requests"], + limits=None if "limits" not in resources else resources["limits"], + ) + else: # since the provider version is greater than `4.2.0` so should use the `container_resources` argument + args["container_resources"] = client.V1ResourceRequirements( + requests=resources["requests"], + limits=None if "limits" not in resources else resources["limits"], + ) + del args["resources"] + + if operator_args.get("execution_timeout"): + args["execution_timeout"] = timedelta( + **operator_args.get( + "execution_timeout", + ) + ) + if operator_args.get("retry_delay"): + args["retry_delay"] = timedelta(**operator_args.get("retry_delay")) + return args + + +def _parse_sensor_args(name, kwargs): + if name == SensorNames.EXTERNAL_TASK_SENSOR: + if "execution_delta" in kwargs: + if type(kwargs["execution_delta"]) == dict: + kwargs["execution_delta"] = timedelta(**kwargs["execution_delta"]) + else: + del kwargs["execution_delta"] + return kwargs + + +def _get_sensor(name): + # from airflow import XComArg + # XComArg() + if name == SensorNames.EXTERNAL_TASK_SENSOR: + # ExternalTaskSensors uses an execution_date of a dag to + # determine the appropriate DAG. + # This is set to the exact date the current dag gets executed on. + # For example if "DagA" (Upstream DAG) got scheduled at + # 12 Jan 4:00 PM PDT then "DagB"(current DAG)'s task sensor will try to + # look for a "DagA" that got executed at 12 Jan 4:00 PM PDT **exactly**. + # They also support a `execution_timeout` argument to + from airflow.sensors.external_task_sensor import ExternalTaskSensor + + return ExternalTaskSensor + elif name == SensorNames.S3_SENSOR: + try: + from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor + except ImportError: + raise AirflowSensorNotFound( + "This DAG requires a `S3KeySensor`. " + "Install the Airflow AWS provider using : " + "`pip install apache-airflow-providers-amazon`" + ) + return S3KeySensor + elif name == SensorNames.SQL_SENSOR: + from airflow.sensors.sql import SqlSensor + + return SqlSensor + + +def get_metaflow_kuberentes_operator(): + try: + from airflow.contrib.operators.kubernetes_pod_operator import ( + KubernetesPodOperator, + ) + except ImportError: + try: + from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import ( + KubernetesPodOperator, + ) + except ImportError as e: + raise KubernetesProviderNotFound( + "This DAG utilizes `KubernetesPodOperator`. 
" + "Install the Airflow Kubernetes provider using " + "`%s -m pip install apache-airflow-providers-cncf-kubernetes`" + % sys.executable + ) + + class MetaflowKubernetesOperator(KubernetesPodOperator): + """ + ## Why Inherit the `KubernetesPodOperator` class ? + + Two key reasons : + + 1. So that we can override the `execute` method. + The only change we introduce to the method is to explicitly modify xcom relating to `return_values`. + We do this so that the `XComArg` object can work with `expand` function. + + 2. So that we can introduce an keyword argument named `mapper_arr`. + This keyword argument can help as a dummy argument for the `KubernetesPodOperator.partial().expand` method. Any Airflow Operator can be dynamically mapped to runtime artifacts using `Operator.partial(**kwargs).extend(**mapper_kwargs)` post the introduction of [Dynamic Task Mapping](https://airflow.apache.org/docs/apache-airflow/stable/concepts/dynamic-task-mapping.html). + The `expand` function takes keyword arguments taken by the operator. + + ## Why override the `execute` method ? + + When we dynamically map vanilla Airflow operators with artifacts generated at runtime, we need to pass that information via `XComArg` to a operator's keyword argument in the `expand` [function](https://airflow.apache.org/docs/apache-airflow/stable/concepts/dynamic-task-mapping.html#mapping-over-result-of-classic-operators). + The `XComArg` object retrieves XCom values for a particular task based on a `key`, the default key being `return_values`. + Oddly dynamic task mapping [doesn't support XCom values from any other key except](https://github.com/apache/airflow/blob/8a34d25049a060a035d4db4a49cd4a0d0b07fb0b/airflow/models/mappedoperator.py#L150) `return_values` + The values of XCom passed by the `KubernetesPodOperator` are mapped to the `return_values` XCom key. + + The biggest problem this creates is that the values of the Foreach cadinality are stored inside the dictionary of `return_values` and cannot be accessed trivially like : `XComArg(task)['foreach_key']` since they are resolved during runtime. + This puts us in a bind since the only xcom we can retrieve is the full dictionary and we cannot pass that as the iteratable for the mapper tasks. + Hence we inherit the `execute` method and push custom xcom keys (needed by downstream tasks such as metaflow taskids) and modify `return_values` captured from the container whenever a foreach related xcom is passed. + When we encounter a foreach xcom we resolve the cardinality which is passed to an actual list and return that as `return_values`. + This is later useful in the `Workflow.compile` where the operator's `expand` method is called and we are able to retrieve the xcom value. 
+ """ + + def __init__(self, *args, mapper_arr=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.mapper_arr = mapper_arr + + def execute(self, context): + result = super().execute(context) + if result is None: + return + ti = context["ti"] + if TASK_ID_XCOM_KEY in result: + ti.xcom_push( + key=TASK_ID_XCOM_KEY, + value=result[TASK_ID_XCOM_KEY], + ) + if FOREACH_CARDINALITY_XCOM_KEY in result: + return list(range(result[FOREACH_CARDINALITY_XCOM_KEY])) + + return MetaflowKubernetesOperator + + +class AirflowTask(object): + def __init__( + self, name, operator_type="kubernetes", flow_name=None, is_mapper_node=False + ): + self.name = name + self._is_mapper_node = is_mapper_node + self._operator_args = None + self._operator_type = operator_type + self._flow_name = flow_name + + @property + def is_mapper_node(self): + return self._is_mapper_node + + def set_operator_args(self, **kwargs): + self._operator_args = kwargs + return self + + def _make_sensor(self): + TaskSensor = _get_sensor(self._operator_type) + return TaskSensor( + task_id=self.name, + **_parse_sensor_args(self._operator_type, self._operator_args) + ) + + def to_dict(self): + return { + "name": self.name, + "is_mapper_node": self._is_mapper_node, + "operator_type": self._operator_type, + "operator_args": self._operator_args, + } + + @classmethod + def from_dict(cls, task_dict, flow_name=None): + op_args = {} if not "operator_args" in task_dict else task_dict["operator_args"] + is_mapper_node = ( + False if "is_mapper_node" not in task_dict else task_dict["is_mapper_node"] + ) + return cls( + task_dict["name"], + is_mapper_node=is_mapper_node, + operator_type=task_dict["operator_type"] + if "operator_type" in task_dict + else "kubernetes", + flow_name=flow_name, + ).set_operator_args(**op_args) + + def _kubenetes_task(self): + MetaflowKubernetesOperator = get_metaflow_kuberentes_operator() + k8s_args = _kubernetes_pod_operator_args(self._operator_args) + return MetaflowKubernetesOperator(**k8s_args) + + def _kubernetes_mapper_task(self): + MetaflowKubernetesOperator = get_metaflow_kuberentes_operator() + k8s_args = _kubernetes_pod_operator_args(self._operator_args) + return MetaflowKubernetesOperator.partial(**k8s_args) + + def to_task(self): + if self._operator_type == "kubernetes": + if not self.is_mapper_node: + return self._kubenetes_task() + else: + return self._kubernetes_mapper_task() + elif self._operator_type in SensorNames.get_supported_sensors(): + return self._make_sensor() + + +class Workflow(object): + def __init__(self, file_path=None, graph_structure=None, metadata=None, **kwargs): + self._dag_instantiation_params = AirflowDAGArgs(**kwargs) + self._file_path = file_path + self._metadata = metadata + tree = lambda: defaultdict(tree) + self.states = tree() + self.metaflow_params = None + self.graph_structure = graph_structure + + def set_parameters(self, params): + self.metaflow_params = params + + def add_state(self, state): + self.states[state.name] = state + + def to_dict(self): + return dict( + metadata=self._metadata, + graph_structure=self.graph_structure, + states={s: v.to_dict() for s, v in self.states.items()}, + dag_instantiation_params=self._dag_instantiation_params.serialize(), + file_path=self._file_path, + metaflow_params=self.metaflow_params, + ) + + def to_json(self): + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, data_dict): + re_cls = cls( + file_path=data_dict["file_path"], + graph_structure=data_dict["graph_structure"], + metadata={} if "metadata" not in 
data_dict else data_dict["metadata"], + ) + re_cls._dag_instantiation_params = AirflowDAGArgs.deserialize( + data_dict["dag_instantiation_params"] + ) + + for sd in data_dict["states"].values(): + re_cls.add_state( + AirflowTask.from_dict( + sd, flow_name=re_cls._dag_instantiation_params.arguments["dag_id"] + ) + ) + re_cls.set_parameters(data_dict["metaflow_params"]) + return re_cls + + @classmethod + def from_json(cls, json_string): + data = json.loads(json_string) + return cls.from_dict(data) + + def _construct_params(self): + from airflow.models.param import Param + + if self.metaflow_params is None: + return {} + param_dict = {} + for p in self.metaflow_params: + name = p["name"] + del p["name"] + param_dict[name] = Param(**p) + return param_dict + + def compile(self): + from airflow import DAG + from airflow import XComArg + + _validate_minimum_airflow_version() + + if self._metadata["contains_foreach"]: + _validate_dyanmic_mapping_compatibility() + # We need to verify if KubernetesPodOperator is of version > 4.2.0 to support foreachs / dynamic task mapping. + # If the dag uses dynamic Task mapping then we throw an error since the `resources` argument in the `KuberentesPodOperator` + # doesn't work for dynamic task mapping for `KuberentesPodOperator` version < 4.2.0. + # For more context check this issue : https://github.com/apache/airflow/issues/24669 + _check_foreach_compatible_kubernetes_provider() + + params_dict = self._construct_params() + # DAG Params can be seen here : + # https://airflow.apache.org/docs/apache-airflow/2.0.0/_api/airflow/models/dag/index.html#airflow.models.dag.DAG + # Airflow 2.0.0 Allows setting Params. + dag = DAG(params=params_dict, **self._dag_instantiation_params.arguments) + dag.fileloc = self._file_path if self._file_path is not None else dag.fileloc + + def add_node(node, parents, dag): + """ + A recursive function to traverse the specialized + graph_structure datastructure. + """ + if type(node) == str: + task = self.states[node].to_task() + if parents: + for parent in parents: + # Handle foreach nodes. + if self.states[node].is_mapper_node: + task = task.expand(mapper_arr=XComArg(parent)) + parent >> task + return [task] # Return Parent + + # this means a split from parent + if type(node) == list: + # this means branching since everything within the list is a list + if all(isinstance(n, list) for n in node): + curr_parents = parents + parent_list = [] + for node_list in node: + last_parent = add_node(node_list, curr_parents, dag) + parent_list.extend(last_parent) + return parent_list + else: + # this means no branching and everything within the list is not a list and can be actual nodes. 
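# --- Illustrative aside (hypothetical flow; not part of add_node) ---------------
# `graph_structure` is a nested list: a string is a single task, a flat list is
# a linear chain, and a list of lists is a branch whose sublists share the same
# parents. For a flow shaped like  start -> (a -> c | b) -> join -> end  the
# compiler would hand this function something like:
example_graph_structure = [
    "start",
    [["a", "c"], ["b"]],  # two parallel branches fanning out of `start`
    "join",
    "end",
]
# add_node() walks this structure, wiring `parent >> task` edges as it goes and
# calling `.expand(mapper_arr=XComArg(parent))` whenever the node is a foreach
# mapper.
# ---------------------------------------------------------------------------------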
+ curr_parents = parents + for node_x in node: + curr_parents = add_node(node_x, curr_parents, dag) + return curr_parents + + with dag: + parent = None + for node in self.graph_structure: + parent = add_node(node, parent, dag) + + return dag diff --git a/metaflow/plugins/airflow/dag.py b/metaflow/plugins/airflow/dag.py new file mode 100644 index 00000000000..2720fe40e0e --- /dev/null +++ b/metaflow/plugins/airflow/dag.py @@ -0,0 +1,9 @@ +# Deployed on {{deployed_on}} + +CONFIG = {{{config}}} + +{{{utils}}} + +dag = Workflow.from_dict(CONFIG).compile() +with dag: + pass diff --git a/metaflow/plugins/airflow/exception.py b/metaflow/plugins/airflow/exception.py new file mode 100644 index 00000000000..a76a755e22c --- /dev/null +++ b/metaflow/plugins/airflow/exception.py @@ -0,0 +1,12 @@ +from metaflow.exception import MetaflowException + + +class AirflowException(MetaflowException): + headline = "Airflow Exception" + + def __init__(self, msg): + super().__init__(msg) + + +class NotSupportedException(MetaflowException): + headline = "Not yet supported with Airflow" diff --git a/metaflow/plugins/airflow/plumbing/set_parameters.py b/metaflow/plugins/airflow/plumbing/set_parameters.py new file mode 100644 index 00000000000..7a2e4dd3112 --- /dev/null +++ b/metaflow/plugins/airflow/plumbing/set_parameters.py @@ -0,0 +1,21 @@ +import os +import json +import sys + + +def export_parameters(output_file): + input = json.loads(os.environ.get("METAFLOW_PARAMETERS", "{}")) + with open(output_file, "w") as f: + for k in input: + # Replace `-` with `_` is parameter names since `-` isn't an + # allowed character for environment variables. cli.py will + # correctly translate the replaced `-`s. + f.write( + "export METAFLOW_INIT_%s=%s\n" + % (k.upper().replace("-", "_"), json.dumps(input[k])) + ) + os.chmod(output_file, 509) + + +if __name__ == "__main__": + export_parameters(sys.argv[1]) diff --git a/metaflow/plugins/airflow/sensors/__init__.py b/metaflow/plugins/airflow/sensors/__init__.py new file mode 100644 index 00000000000..02952d0c9a4 --- /dev/null +++ b/metaflow/plugins/airflow/sensors/__init__.py @@ -0,0 +1,9 @@ +from .external_task_sensor import ExternalTaskSensorDecorator +from .s3_sensor import S3KeySensorDecorator +from .sql_sensor import SQLSensorDecorator + +SUPPORTED_SENSORS = [ + ExternalTaskSensorDecorator, + S3KeySensorDecorator, + SQLSensorDecorator, +] diff --git a/metaflow/plugins/airflow/sensors/base_sensor.py b/metaflow/plugins/airflow/sensors/base_sensor.py new file mode 100644 index 00000000000..9412072cd23 --- /dev/null +++ b/metaflow/plugins/airflow/sensors/base_sensor.py @@ -0,0 +1,74 @@ +import uuid +from metaflow.decorators import FlowDecorator +from ..exception import AirflowException +from ..airflow_utils import AirflowTask, id_creator, TASK_ID_HASH_LEN + + +class AirflowSensorDecorator(FlowDecorator): + """ + Base class for all Airflow sensor decorators. + """ + + allow_multiple = True + + defaults = dict( + timeout=3600, + poke_interval=60, + mode="reschedule", + exponential_backoff=True, + pool=None, + soft_fail=False, + name=None, + description=None, + ) + + operator_type = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._airflow_task_name = None + self._id = str(uuid.uuid4()) + + def serialize_operator_args(self): + """ + Subclasses will parse the decorator arguments to + Airflow task serializable arguments. 
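
        For illustration (a rough sketch with made-up values, mirroring
        `ExternalTaskSensorDecorator`): a decorator such as

            @airflow_external_task_sensor(external_dag_id="upstream_etl",
                                          execution_delta=timedelta(hours=1))

        serializes to operator arguments roughly like

            {"external_dag_id": "upstream_etl",
             "execution_delta": {"seconds": 3600.0},
             "mode": "reschedule",
             "do_xcom_push": True,
             ...}

        with `None`-valued entries dropped later by `create_task`.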
+ """ + task_args = dict(**self.attributes) + del task_args["name"] + if task_args["description"] is not None: + task_args["doc"] = task_args["description"] + del task_args["description"] + task_args["do_xcom_push"] = True + return task_args + + def create_task(self): + task_args = self.serialize_operator_args() + return AirflowTask( + self._airflow_task_name, + operator_type=self.operator_type, + ).set_operator_args(**{k: v for k, v in task_args.items() if v is not None}) + + def validate(self): + """ + Validate if the arguments for the sensor are correct. + """ + # If there is no name set then auto-generate the name. This is done because there can be more than + # one `AirflowSensorDecorator` of the same type. + if self.attributes["name"] is None: + deco_index = [ + d._id + for d in self._flow_decorators + if issubclass(d.__class__, AirflowSensorDecorator) + ].index(self._id) + self._airflow_task_name = "%s-%s" % ( + self.operator_type, + id_creator([self.operator_type, str(deco_index)], TASK_ID_HASH_LEN), + ) + else: + self._airflow_task_name = self.attributes["name"] + + def flow_init( + self, flow, graph, environment, flow_datastore, metadata, logger, echo, options + ): + self.validate() diff --git a/metaflow/plugins/airflow/sensors/external_task_sensor.py b/metaflow/plugins/airflow/sensors/external_task_sensor.py new file mode 100644 index 00000000000..649edba706c --- /dev/null +++ b/metaflow/plugins/airflow/sensors/external_task_sensor.py @@ -0,0 +1,94 @@ +from .base_sensor import AirflowSensorDecorator +from ..airflow_utils import SensorNames +from ..exception import AirflowException +from datetime import timedelta + + +AIRFLOW_STATES = dict( + QUEUED="queued", + RUNNING="running", + SUCCESS="success", + SHUTDOWN="shutdown", # External request to shut down, + FAILED="failed", + UP_FOR_RETRY="up_for_retry", + UP_FOR_RESCHEDULE="up_for_reschedule", + UPSTREAM_FAILED="upstream_failed", + SKIPPED="skipped", +) + + +class ExternalTaskSensorDecorator(AirflowSensorDecorator): + operator_type = SensorNames.EXTERNAL_TASK_SENSOR + # Docs: + # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/external_task/index.html#airflow.sensors.external_task.ExternalTaskSensor + name = "airflow_external_task_sensor" + defaults = dict( + **AirflowSensorDecorator.defaults, + external_dag_id=None, + external_task_ids=None, + allowed_states=[AIRFLOW_STATES["SUCCESS"]], + failed_states=None, + execution_delta=None, + check_existence=True, + # We cannot add `execution_date_fn` as it requires a python callable. + # Passing around a python callable is non-trivial since we are passing a + # callable from metaflow-code to airflow python script. Since we cannot + # transfer dependencies of the callable, we cannot gaurentee that the callable + # behave exactly as the user expects + ) + + def serialize_operator_args(self): + task_args = super().serialize_operator_args() + if task_args["execution_delta"] is not None: + task_args["execution_delta"] = dict( + seconds=task_args["execution_delta"].total_seconds() + ) + return task_args + + def validate(self): + if self.attributes["external_dag_id"] is None: + raise AirflowException( + "`%s` argument of `@%s`cannot be `None`." + % ("external_dag_id", self.name) + ) + + if type(self.attributes["allowed_states"]) == str: + if self.attributes["allowed_states"] not in list(AIRFLOW_STATES.values()): + raise AirflowException( + "`%s` is an invalid input of `%s` for `@%s`. 
Accepted values are %s"
+                    % (
+                        str(self.attributes["allowed_states"]),
+                        "allowed_states",
+                        self.name,
+                        ", ".join(list(AIRFLOW_STATES.values())),
+                    )
+                )
+        elif type(self.attributes["allowed_states"]) == list:
+            enum_not_matched = [
+                x
+                for x in self.attributes["allowed_states"]
+                if x not in list(AIRFLOW_STATES.values())
+            ]
+            if len(enum_not_matched) > 0:
+                raise AirflowException(
+                    "`%s` is an invalid input of `%s` for `@%s`. Accepted values are %s"
+                    % (
+                        str(" OR ".join(["'%s'" % i for i in enum_not_matched])),
+                        "allowed_states",
+                        self.name,
+                        ", ".join(list(AIRFLOW_STATES.values())),
+                    )
+                )
+        else:
+            self.attributes["allowed_states"] = [AIRFLOW_STATES["SUCCESS"]]
+
+        if self.attributes["execution_delta"] is not None:
+            if not isinstance(self.attributes["execution_delta"], timedelta):
+                raise AirflowException(
+                    "`%s` is an invalid input type of `execution_delta` for `@%s`. Accepted type is `datetime.timedelta`"
+                    % (
+                        str(type(self.attributes["execution_delta"])),
+                        self.name,
+                    )
+                )
+        super().validate()
diff --git a/metaflow/plugins/airflow/sensors/s3_sensor.py b/metaflow/plugins/airflow/sensors/s3_sensor.py
new file mode 100644
index 00000000000..b4f7ae5b6de
--- /dev/null
+++ b/metaflow/plugins/airflow/sensors/s3_sensor.py
@@ -0,0 +1,26 @@
+from .base_sensor import AirflowSensorDecorator
+from ..airflow_utils import SensorNames
+from ..exception import AirflowException
+
+
+class S3KeySensorDecorator(AirflowSensorDecorator):
+    name = "airflow_s3_key_sensor"
+    operator_type = SensorNames.S3_SENSOR
+    # Arg specification can be found here :
+    # https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/_api/airflow/providers/amazon/aws/sensors/s3/index.html#airflow.providers.amazon.aws.sensors.s3.S3KeySensor
+    defaults = dict(
+        **AirflowSensorDecorator.defaults,
+        bucket_key=None,  # Required
+        bucket_name=None,
+        wildcard_match=False,
+        aws_conn_id=None,
+        verify=None,  # `verify (Optional[Union[str, bool]])`: whether or not to verify SSL certificates for the S3 connection.
+        # `verify` is an Airflow variable.
+    )
+
+    def validate(self):
+        if self.attributes["bucket_key"] is None:
+            raise AirflowException(
+                "`bucket_key` for `@%s` cannot be empty." % (self.name)
+            )
+        super().validate()
diff --git a/metaflow/plugins/airflow/sensors/sql_sensor.py b/metaflow/plugins/airflow/sensors/sql_sensor.py
new file mode 100644
index 00000000000..c97c41b283e
--- /dev/null
+++ b/metaflow/plugins/airflow/sensors/sql_sensor.py
@@ -0,0 +1,30 @@
+from .base_sensor import AirflowSensorDecorator
+from ..airflow_utils import SensorNames
+from ..exception import AirflowException
+
+
+class SQLSensorDecorator(AirflowSensorDecorator):
+    name = "airflow_sql_sensor"
+    operator_type = SensorNames.SQL_SENSOR
+    # Arg specification can be found here :
+    # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/sql/index.html#airflow.sensors.sql.SqlSensor
+    defaults = dict(
+        **AirflowSensorDecorator.defaults,
+        conn_id=None,
+        sql=None,
+        # success = None,  # success/failure require callables. Won't be supported for now since callables are not serialization-friendly.
+        # failure = None,
+        parameters=None,
+        fail_on_empty=True,
+    )
+
+    def validate(self):
+        if self.attributes["conn_id"] is None:
+            raise AirflowException(
+                "`%s` argument of `@%s` cannot be `None`." % ("conn_id", self.name)
+            )
+        if self.attributes["sql"] is None:
+            raise AirflowException(
+                "`%s` argument of `@%s` cannot be `None`." % ("sql", self.name)
+            )
+        super().validate()
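# --- Illustrative usage sketch (hypothetical flow; not part of the patch) -------
# How the sensor decorators defined above would typically be attached to a flow,
# assuming the decorator names (`airflow_external_task_sensor`,
# `airflow_s3_key_sensor`) are exposed as top-level Metaflow imports like other
# flow-level decorators:
from metaflow import (
    FlowSpec,
    step,
    airflow_external_task_sensor,
    airflow_s3_key_sensor,
)


@airflow_external_task_sensor(external_dag_id="upstream_etl", timeout=1800)
@airflow_s3_key_sensor(bucket_key="s3://my-bucket/daily/_SUCCESS")
class DownstreamFlow(FlowSpec):
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    DownstreamFlow()
# Compiling with `python downstream_flow.py airflow create downstream_dag.py`
# would then emit a DAG in which both sensor tasks run ahead of the `start` step.
# ---------------------------------------------------------------------------------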