Skip to content

Commit

Permalink
Tweeks
Browse files Browse the repository at this point in the history
  • Loading branch information
valayDave committed Mar 11, 2022
1 parent a9f0468 commit db074b8
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions metaflow/plugins/airflow/airflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def to_dict(self):


class AirflowTask(object):
def __init__(self, name):
def __init__(self, name, operator_args=None, operator_type="kubernetes"):
self.name = name
self._operator_args = operator_args
self._operator_type = operator_type
self._next = None
self._operator = None
self._operator_args = None

@property
def next_state(self):
Expand All @@ -102,22 +102,42 @@ def to_dict(self):
return {
"name": self.name,
"next": self._next,
"operator_type": self._operator_type,
"operator_args": self._operator_args,
}

@classmethod
def from_dict(cls, jsd):
return cls(jsd["name"]).next(jsd["next"])

def to_task(self):
# todo fix
from airflow.operators.bash import BashOperator
return cls(
jsd["name"],
operator_type=jsd["operator_type"]
if "operator_type" in jsd
else "kubernetes",
).next(jsd["next"])

def _kubenetes_task(self):
from airflow.contrib.operators.kubernetes_pod_operator import (
KubernetesPodOperator,
)

return BashOperator(
return KubernetesPodOperator(
namespace="airflow",
image="python",
cmds=["python", "-c"],
arguments=["print('{{ task }}')"],
labels={"foo": "bar"},
image_pull_policy="Always",
name=self.name,
task_id=self.name,
depends_on_past=True,
bash_command="sleep 1",
is_delete_operator_pod=True,
get_logs=True,
)

def to_task(self):
# todo fix
if self._operator_type == "kubernetes":
return self._kubenetes_task()


class Workflow(object):
def __init__(self, file_path=None, **kwargs):
Expand Down

0 comments on commit db074b8

Please sign in to comment.