[Examples] Temporal example #4017

Open · wants to merge 3 commits into base: master
**examples/temporal/README.md** (78 additions)
# Running SkyPilot Tasks in Temporal Workflows

This example demonstrates how to launch SkyPilot tasks and manage them in a Temporal workflow.

<p align="center">
<img src="https://i.imgur.com/rxlO2pJ.png" width="800">
</p>

All activities, such as launching clusters, executing tasks, and tearing down clusters, are run on the same worker, eliminating the need for SkyPilot's state management across multiple workers.

## Defining the Tasks

We will define the following tasks to mock a training workflow:
1. **`data_preprocessing.yaml`**: Generates data and writes it to a bucket.
2. **`train.yaml`**: Trains a model using the data in the bucket.
3. **`eval.yaml`**: Evaluates the model and writes the evaluation results to the bucket.

These tasks are defined in the [mock_train_workflow](https://github.com/romilbhardwaj/mock_train_workflow) repository, which the workflow clones before executing the tasks.
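The tasks share data through the bucket, whose URL is forwarded to each task as an environment variable. A minimal sketch of how such a flag can be built (`data_bucket_flag` is a hypothetical helper; the URL is a placeholder):

```python
from typing import Optional

def data_bucket_flag(url: Optional[str]) -> str:
    """Build the --env flag forwarded to each sky command ('' if no bucket)."""
    return f'--env DATA_BUCKET_URL={url}' if url else ''

print(data_bucket_flag('gs://sky-example-data'))  # --env DATA_BUCKET_URL=gs://sky-example-data
print(repr(data_bucket_flag(None)))               # ''
```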

## Workflow Overview

We define a Temporal workflow consisting of the following steps:

1. **Clone the repository containing tasks** using `git`.
2. **Launch a SkyPilot cluster** to run the data preprocessing job.
3. **Terminate the cluster** after preprocessing.
4. **Launch another cluster** for training the model.
5. **Execute an evaluation task** on the same training cluster.
6. **Terminate the cluster** after evaluation.
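The sequence above can be sketched with plain async functions standing in for the Temporal activities (no server needed; the function names here are illustrative, not part of the example):

```python
import asyncio

async def launch(cluster: str, task: str) -> str:
    return f'launched {task} on {cluster}'

async def exec_task(cluster: str, task: str) -> str:
    return f'ran {task} on {cluster}'

async def down(cluster: str) -> str:
    return f'tore down {cluster}'

async def mock_workflow(prefix: str = 'demo') -> list:
    # Mirrors steps 2-6: preprocess, teardown, train, eval, teardown.
    return [
        await launch(f'{prefix}-preprocess', 'data_preprocessing.yaml'),
        await down(f'{prefix}-preprocess'),
        await launch(f'{prefix}-train', 'train.yaml'),
        await exec_task(f'{prefix}-train', 'eval.yaml'),
        await down(f'{prefix}-train'),
    ]

print(asyncio.run(mock_workflow()))
```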

### Temporal Activities

These steps are implemented as Temporal activities, which are functions that can be executed by the Temporal worker:

- **`run_sky_launch`**: Launches a SkyPilot cluster with a specified configuration.
- **`run_sky_down`**: Terminates the specified SkyPilot cluster.
- **`run_sky_exec`**: Executes a task on an existing SkyPilot cluster.
- **`run_git_clone`**: Clones a Git repository to a specified location.
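Each activity follows the same subprocess pattern: run the CLI command, capture its output, and raise on failure so Temporal can surface the error. A minimal sketch of that shared pattern (`echo` stands in for `sky` so it runs anywhere):

```python
import subprocess

def run_cli(command: str) -> str:
    """Run a CLI command, mirroring the activities: capture output and
    raise CalledProcessError on a non-zero exit code (check=True)."""
    result = subprocess.run(command.split(),
                            capture_output=True,
                            text=True,
                            check=True)
    return result.stdout.strip()

print(run_cli('echo hello-cluster'))  # hello-cluster
```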

### Single Worker Execution

In this workflow, all tasks are handled by the same Temporal worker. This simplifies the workflow, as SkyPilot’s internal state does not need to be transferred between different workers, ensuring seamless orchestration.

This is achieved by registering all activities (`run_sky_launch`, `run_sky_down`, `run_sky_exec`, `run_git_clone`) to the same worker and having them consume from the same task queue:

```python
async with Worker(
    client,
    task_queue='skypilot-task-queue',
    workflows=[SkyPilotWorkflow],
    activities=[run_sky_launch, run_sky_down, run_sky_exec, run_git_clone],
):
    ...  # The worker processes workflow and activity tasks while this block is open.
```

## Running the Workflow

1. If running Temporal locally, start the development server:
```bash
temporal server start-dev
```

2. Launch the workflow:
```bash
python skypilot_workflow.py
```

3. Monitor the workflow execution in the Temporal Web UI (typically http://localhost:8233).

4. When the workflow completes, all logs will be available in the Temporal Web UI.

<p align="center">
<img src="https://i.imgur.com/h7LALlX.png" width="800">
</p>
**examples/temporal/skypilot_workflow.py** (241 additions)
import asyncio
from dataclasses import dataclass
from datetime import timedelta
import os
import subprocess

from temporalio import activity
from temporalio import workflow
from temporalio.client import Client
from temporalio.worker import Worker


@dataclass
class SkyLaunchCommand:
cluster_name: str
entrypoint: str
flags: str


@dataclass
class SkyDownCommand:
cluster_name: str


@dataclass
class SkyExecCommand:
cluster_name: str
entrypoint: str
flags: str


@activity.defn
async def run_sky_launch(input: SkyLaunchCommand) -> str:
activity.logger.info(
f'Running Sky Launch on cluster: {input.cluster_name} '
f'with entrypoint: {input.entrypoint} and flags: {input.flags}')

# Run the provided SkyPilot command using subprocess
command = f'sky launch -y -c {input.cluster_name} {input.flags} {input.entrypoint}'

try:
result = subprocess.run(command.split(),
capture_output=True,
text=True,
check=True)
activity.logger.info(f'Sky launch output: {result.stdout}')
return result.stdout.strip() # Return the output from the subprocess
except subprocess.CalledProcessError as e:
activity.logger.error(f'Sky launch failed with error: {e}')
activity.logger.error(f'Stdout: {e.stdout}')
activity.logger.error(f'Stderr: {e.stderr}')
raise # Re-raise the exception to indicate failure


@activity.defn
async def run_sky_down(input: SkyDownCommand) -> str:
activity.logger.info(f'Running Sky Down on cluster: {input.cluster_name}')

# Run the sky down command using subprocess
command = f'sky down -y {input.cluster_name}'

try:
result = subprocess.run(command.split(),
capture_output=True,
text=True,
check=True)
activity.logger.info(f'Sky down output: {result.stdout}')
return result.stdout.strip()
except subprocess.CalledProcessError as e:
activity.logger.error(f'Sky down failed with error: {e}')
activity.logger.error(f'Stdout: {e.stdout}')
activity.logger.error(f'Stderr: {e.stderr}')
raise # Re-raise the exception to indicate failure


@activity.defn
async def run_sky_exec(input: SkyExecCommand) -> str:
activity.logger.info(
f'Running Sky exec on cluster: {input.cluster_name} '
f'with entrypoint: {input.entrypoint} and flags: {input.flags}')

# Run the sky exec command using subprocess
full_command = f'sky exec {input.cluster_name} {input.flags} {input.entrypoint}'

try:
result = subprocess.run(full_command,
shell=True,
capture_output=True,
text=True,
check=True)
activity.logger.info(f'Sky exec output: {result.stdout}')
return result.stdout.strip()
except subprocess.CalledProcessError as e:
activity.logger.error(f'Sky exec failed with error: {e}')
activity.logger.error(f'Stdout: {e.stdout}')
activity.logger.error(f'Stderr: {e.stderr}')
raise # Re-raise the exception to indicate failure


@dataclass
class GitCloneInput:
repo_url: str
clone_path: str


@activity.defn
async def run_git_clone(input: GitCloneInput) -> str:
activity.logger.info(
f'Cloning git repository: {input.repo_url} to {input.clone_path}')

# Create clone path if it doesn't exist
os.makedirs(input.clone_path, exist_ok=True)

# Check if the repository already exists
if os.path.exists(os.path.join(input.clone_path, '.git')):
# If it exists, pull the latest changes
command = f'git -C {input.clone_path} pull'
else:
# If it doesn't exist, clone the repository
command = f'git clone {input.repo_url} {input.clone_path}'

try:
result = subprocess.run(command.split(),
capture_output=True,
text=True,
check=True)
activity.logger.info(f'Git clone output: {result.stdout}')
return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        activity.logger.error(f'Git clone failed with error: {e}')
        activity.logger.error(f'Stdout: {e.stdout}')
        activity.logger.error(f'Stderr: {e.stderr}')
        raise  # Re-raise the exception to indicate failure


@dataclass
class SkyPilotWorkflowInput:
    cluster_prefix: str
    repo_url: str
    data_bucket_url: str = ''  # Empty string disables the --env flag.


@workflow.defn
class SkyPilotWorkflow:

@workflow.run
async def run(self, input: SkyPilotWorkflowInput) -> str:
cluster_prefix = input.cluster_prefix
repo_url = input.repo_url
data_bucket_url = input.data_bucket_url

workflow.logger.info(
f'Running SkyPilot workflow with cluster prefix: {cluster_prefix} ')

# 1. Clone the repository
clone_path = '/tmp/skypilot_repo'
clone_result = await workflow.execute_activity(
run_git_clone,
GitCloneInput(repo_url, clone_path),
start_to_close_timeout=timedelta(minutes=5),
)
workflow.logger.info(f'Clone result: {clone_result}')

data_bucket_flag = '--env DATA_BUCKET_URL=' + data_bucket_url if data_bucket_url else ''

# 2. Launch data preprocessing
cluster_name = f'{cluster_prefix}-preprocess'
preprocess_result = await workflow.execute_activity(
run_sky_launch,
SkyLaunchCommand(cluster_name,
f'{clone_path}/data_preprocessing.yaml',
f'--cloud kubernetes {data_bucket_flag}'),
start_to_close_timeout=timedelta(minutes=30),
)
workflow.logger.info(f'Preprocessing result: {preprocess_result}')

# 3. Down the cluster
down_result = await workflow.execute_activity(
run_sky_down,
SkyDownCommand(cluster_name),
start_to_close_timeout=timedelta(minutes=10),
)
workflow.logger.info(f'Down result: {down_result}')

# 4. Launch training
cluster_name = f'{cluster_prefix}-train'
train_result = await workflow.execute_activity(
run_sky_launch,
SkyLaunchCommand(cluster_name, f'{clone_path}/train.yaml',
f'--cloud kubernetes {data_bucket_flag}'),
start_to_close_timeout=timedelta(minutes=60),
)
workflow.logger.info(f'Training result: {train_result}')

        # 5. Execute evaluation on the same training cluster
eval_result = await workflow.execute_activity(
run_sky_exec,
SkyExecCommand(cluster_name, f'{clone_path}/eval.yaml',
f'{data_bucket_flag}'),
start_to_close_timeout=timedelta(minutes=30),
)
workflow.logger.info(f'Evaluation result: {eval_result}')

# 6. Down the cluster
down_result = await workflow.execute_activity(
run_sky_down,
SkyDownCommand(cluster_name),
start_to_close_timeout=timedelta(minutes=10),
)
workflow.logger.info(f'Down result: {down_result}')

# Return the combined result
return f'Preprocessing: {preprocess_result}, Training: {train_result}, Evaluation: {eval_result}'


async def main():
# Start client
client = await Client.connect('localhost:7233')

# Run a worker for the workflow
async with Worker(
client,
task_queue='skypilot-task-queue',
workflows=[SkyPilotWorkflow],
        activities=[run_sky_launch, run_sky_down, run_sky_exec,
                    run_git_clone],  # Register all activities to the same worker
):
# Execute the workflow with cluster name and config path
result = await client.execute_workflow(
SkyPilotWorkflow.run,
            SkyPilotWorkflowInput(
                cluster_prefix='my-workflow',  # Cluster name prefix.
                repo_url='https://github.com/romilbhardwaj/mock_train_workflow.git',
                data_bucket_url='gs://sky-example-data'),
id='skypilot-workflow-id',
task_queue='skypilot-task-queue',
)
print(f'SkyPilot Workflow Result: {result}')


if __name__ == '__main__':
asyncio.run(main())