Merge pull request #28 from alexuvarovskyi/HW7_PR6
HW7 PR6 Airflow Inference Pipeline
alexuvarovskyi authored Nov 2, 2024
2 parents 2a35712 + 0a482ee commit 1c6e539
Showing 2 changed files with 134 additions and 0 deletions.
119 changes: 119 additions & 0 deletions orchestration/airflow/dags/airflow_inference_pipeline.py
@@ -0,0 +1,119 @@
import os
from datetime import datetime

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s


IMAGE = "alexuvarovskii/training_mlp:latest"
# W&B settings are defined here but not referenced by any task in this DAG.
WANDB_PROJECT = "huggingface"
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

# STORAGE = "training-storage"
STORAGE = "airflow-pipeline-pvc"
# AWS credentials are read from the environment that parses the DAG file and
# passed to each pod as plain command-line arguments.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")


# Shared PVC, mounted by every task so artifacts persist between pods.
volume = k8s.V1Volume(
    name=STORAGE,
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name=STORAGE
    ),
)

volume_mount = k8s.V1VolumeMount(name=STORAGE, mount_path="/tmp", sub_path=None)

with DAG(
    start_date=datetime(2023, 10, 1),
    catchup=False,
    schedule=None,  # manual trigger only
    dag_id="inference_dag",
) as dag:

    # Step 1: pull the validation images from S3 onto the shared volume.
    download_data_operator = KubernetesPodOperator(
        name='download_data_from_s3',
        image=IMAGE,
        cmds=[
            "python",
            "utils/pull_data.py",
            "--access_key", AWS_ACCESS_KEY_ID,
            "--secret_key", AWS_SECRET_ACCESS_KEY,
            "--s3_bucket_name", "mlp-data-2024",
            "--s3_dir_name", "data_mlp/val_50",
            "--local_dir_name", "/tmp/data",
        ],
        task_id='download_data_from_s3',
        in_cluster=False,
        is_delete_operator_pod=False,
        startup_timeout_seconds=600,
        namespace="default2",
        volumes=[volume],
        volume_mounts=[volume_mount],
    )

    # Step 2: pull the trained model weights from S3.
    download_model_operator = KubernetesPodOperator(
        name='download_model_from_s3',
        image=IMAGE,
        cmds=[
            "python",
            "utils/pull_data.py",
            "--access_key", AWS_ACCESS_KEY_ID,
            "--secret_key", AWS_SECRET_ACCESS_KEY,
            "--s3_bucket_name", "mlp-data-2024",
            "--s3_dir_name", "rtdetr_test",
            "--local_dir_name", "/tmp/model",
        ],
        task_id='download_model_from_s3',
        volumes=[volume],
        volume_mounts=[volume_mount],
        in_cluster=False,
        is_delete_operator_pod=False,
        startup_timeout_seconds=600,
        namespace="default2",
    )

    # Step 3: run inference on /tmp/data with the model in /tmp/model,
    # writing annotations to /tmp/ann.
    inference_operator = KubernetesPodOperator(
        name='inference_model',
        image=IMAGE,
        cmds=[
            "python",
            "src/infer_model_cli.py",
            "--model_path", "/tmp/model",
            "--data_path", "/tmp/data",
            "--ann_save_path", "/tmp/ann",
        ],
        task_id='inference_model',
        in_cluster=False,
        is_delete_operator_pod=False,
        startup_timeout_seconds=600,
        namespace="default2",
        volumes=[volume],
        volume_mounts=[volume_mount],
    )

    # Step 4: push the generated annotations back to S3.
    upload_results_operator = KubernetesPodOperator(
        name='upload_results_to_s3',
        image=IMAGE,
        cmds=[
            "python",
            "utils/upload_data.py",
            "--access_key", AWS_ACCESS_KEY_ID,
            "--secret_key", AWS_SECRET_ACCESS_KEY,
            "--s3_bucket_name", "mlp-data-2024",
            "--s3_dir_name", "results_airflow",
            "--local_dir_name", "/tmp/ann",
        ],
        task_id='upload_results_to_s3',
        volumes=[volume],
        volume_mounts=[volume_mount],
        in_cluster=False,
        is_delete_operator_pod=False,
        startup_timeout_seconds=600,
        namespace="default2",
    )

    download_data_operator >> download_model_operator >> inference_operator >> upload_results_operator
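
Before deploying, the task wiring above can be sanity-checked locally with a DagBag smoke test. This is a minimal sketch, assuming an Airflow 2.x environment and that the DAG file lives under orchestration/airflow/dags as in this PR:

# Minimal local smoke test for the DAG wiring (assumes Airflow 2.x).
from airflow.models import DagBag

dagbag = DagBag(dag_folder="orchestration/airflow/dags", include_examples=False)
assert not dagbag.import_errors, dagbag.import_errors

dag = dagbag.get_dag("inference_dag")
assert dag is not None

# The four pods must run strictly in sequence.
assert [t.task_id for t in dag.topological_sort()] == [
    "download_data_from_s3",
    "download_model_from_s3",
    "inference_model",
    "upload_results_to_s3",
]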


15 changes: 15 additions & 0 deletions training/src/infer_model_cli.py
@@ -0,0 +1,15 @@
import argparse

from inference import inference_model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--ann_save_path", type=str, required=True)
    parser.add_argument("--conf_thresh", type=float, default=0.2)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Note: --conf_thresh is parsed but not passed on to inference_model.
    inference_model(args.model_path, args.data_path, args.ann_save_path)
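
If inference_model accepts a confidence threshold, the parsed --conf_thresh could be forwarded as well. This is a sketch under the assumption that inference_model takes a conf_thresh keyword argument, which this diff does not confirm:

# Hypothetical call, assuming inference_model(...) accepts conf_thresh.
inference_model(
    args.model_path,
    args.data_path,
    args.ann_save_path,
    conf_thresh=args.conf_thresh,
)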
