Commit f943d4d

refactored data sampling to be more pluggable

ArchieGertsman committed Dec 31, 2023
1 parent f92ed55

Showing 19 changed files with 454 additions and 718 deletions.
Binary file removed: .DS_Store
4 changes: 3 additions & 1 deletion .gitignore

@@ -2,9 +2,11 @@ __pycache__/
 .pytest_cache/
 .ipynb_checkpoints/
 *.egg-info/
+.DS_Store
 
 artifacts/
 data/
 count_params.py
 eval*
-config/test.yaml
+config/test.yaml
+test/old_test.py
7 changes: 1 addition & 6 deletions README.md

@@ -28,11 +28,6 @@ Enhancements include:
 
 ---
 
-After cloning this repo, please run `pip install -r requirements.txt` to install the project's dependencies. Then, please manually install `torch_scatter` and `torch_sparse` by running e.g.
-```
-pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
-pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
-```
-They are commented out in `requirements.txt` because `torch` needs to be installed first. See [here](https://github.com/pyg-team/pytorch_geometric/issues/861) for more.
+After cloning this repo, please run `pip install -r requirements.txt` to install the project's dependencies.
 
 To start out, try running examples via `examples.py --sched [fair|decima]`. To train Decima from scratch, modify the provided config file `config/decima_tpch.yaml` as needed, then provide the config to `train.py -f CFG_FILE`.
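
With this change, a first session after cloning reduces to the following; the `python` invocation is an assumption, and the config path comes from the README text above:

```
pip install -r requirements.txt
python examples.py --sched fair             # or: --sched decima
python train.py -f config/decima_tpch.yaml  # train Decima from scratch
```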
93 changes: 44 additions & 49 deletions examples.py

@@ -1,5 +1,5 @@
-'''Examples of how to run job scheduling simulations with different schedulers
-'''
+"""Examples of how to run job scheduling simulations with different schedulers
+"""
 import os.path as osp
 from pprint import pprint
 
@@ -13,89 +13,84 @@
 from spark_sched_sim import metrics
 
 
-ENV_KWARGS = {
-    'num_executors': 10,
-    'job_arrival_cap': 50,
-    'job_arrival_rate': 4.e-5,
-    'moving_delay': 2000.,
-    'warmup_delay': 1000.,
-    'dataset': 'tpch',
-    'render_mode': 'human'
+ENV_CFG = {
+    "num_executors": 10,
+    "job_arrival_cap": 50,
+    "job_arrival_rate": 4.0e-5,
+    "moving_delay": 2000.0,
+    "warmup_delay": 1000.0,
+    "data_sampler_cls": "TPCHDataSampler",
+    "render_mode": "human",
 }
 
 
 def main():
     # save final rendering to artifacts dir
-    pathlib.Path('artifacts').mkdir(parents=True, exist_ok=True)
+    pathlib.Path("artifacts").mkdir(parents=True, exist_ok=True)
 
     parser = ArgumentParser(
-        description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)
-
+        description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
+    )
 
     parser.add_argument(
-        '--sched',
-        choices=['fair', 'decima'],
-        dest='sched',
-        help='which scheduler to run',
+        "--sched",
+        choices=["fair", "decima"],
+        dest="sched",
+        help="which scheduler to run",
         required=True,
     )
 
     args = parser.parse_args()
 
-    sched_map = {
-        'fair': fair_example,
-        'decima': decima_example
-    }
+    sched_map = {"fair": fair_example, "decima": decima_example}
 
     sched_map[args.sched]()
 
 
-
 def fair_example():
     # Fair scheduler
-    scheduler = RoundRobinScheduler(ENV_KWARGS['num_executors'],
-                                    dynamic_partition=True)
-
-    print(f'Example: Fair Scheduler')
-    print('Env settings:')
-    pprint(ENV_KWARGS)
+    scheduler = RoundRobinScheduler(ENV_CFG["num_executors"], dynamic_partition=True)
 
-    print('Running episode...')
-    avg_job_duration = run_episode(ENV_KWARGS, scheduler)
+    print(f"Example: Fair Scheduler")
+    print("Env settings:")
+    pprint(ENV_CFG)
 
-    print(f'Done! Average job duration: {avg_job_duration:.1f}s', flush=True)
-    print()
+    print("Running episode...")
+    avg_job_duration = run_episode(ENV_CFG, scheduler)
 
+    print(f"Done! Average job duration: {avg_job_duration:.1f}s", flush=True)
+    print()
 
 
 def decima_example():
-    cfg = load(filename=osp.join('config', 'decima_tpch.yaml'))
+    cfg = load(filename=osp.join("config", "decima_tpch.yaml"))
 
-    agent_cfg = cfg['agent'] \
-        | {'num_executors': ENV_KWARGS['num_executors'],
-           'state_dict_path': osp.join('models', 'decima', 'model.pt')}
-
-    scheduler = make_scheduler(agent_cfg)
+    agent_cfg = cfg["agent"] | {
+        "num_executors": ENV_CFG["num_executors"],
+        "state_dict_path": osp.join("models", "decima", "model.pt"),
+    }
 
-    print(f'Example: Decima')
-    print('Env settings:')
-    pprint(ENV_KWARGS)
+    scheduler = make_scheduler(agent_cfg)
 
-    print('Running episode...')
-    avg_job_duration = run_episode(ENV_KWARGS, scheduler)
+    print(f"Example: Decima")
+    print("Env settings:")
+    pprint(ENV_CFG)
 
-    print(f'Done! Average job duration: {avg_job_duration:.1f}s', flush=True)
+    print("Running episode...")
+    avg_job_duration = run_episode(ENV_CFG, scheduler)
+
+    print(f"Done! Average job duration: {avg_job_duration:.1f}s", flush=True)
 
 
-def run_episode(env_kwargs, scheduler, seed=1234):
-    env = gym.make('spark_sched_sim:SparkSchedSimEnv-v0', **env_kwargs)
+def run_episode(env_cfg, scheduler, seed=1234):
+    env = gym.make("spark_sched_sim:SparkSchedSimEnv-v0", env_cfg=env_cfg)
     if isinstance(scheduler, NeuralScheduler):
         env = NeuralActWrapper(env)
         env = scheduler.obs_wrapper_cls(env)
 
     obs, _ = env.reset(seed=seed, options=None)
     terminated = truncated = False
 
     while not (terminated or truncated):
         if isinstance(scheduler, NeuralScheduler):
             action, *_ = scheduler(obs)
@@ -111,5 +106,5 @@ def run_episode(env_kwargs, scheduler, seed=1234):
     return avg_job_duration
 
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    main()
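
Beyond the formatting pass (single to double quotes, black-style layout), two substantive changes stand out in this file: the settings dict is passed to `gym.make` as a single `env_cfg` argument instead of splatted keyword arguments, and the dataset is now chosen by sampler class name via the new `data_sampler_cls` key, which replaces the old `'dataset': 'tpch'` entry. A minimal sketch of swapping in another sampler, assuming the environment resolves the class by name; `MyDataSampler` is a placeholder, not a class in this commit:

```python
# Hypothetical override of the new pluggable-sampler key.
# "MyDataSampler" is a placeholder name; only TPCHDataSampler appears in this commit.
custom_cfg = ENV_CFG | {"data_sampler_cls": "MyDataSampler"}

scheduler = RoundRobinScheduler(custom_cfg["num_executors"], dynamic_partition=True)
avg_job_duration = run_episode(custom_cfg, scheduler)
```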
6 changes: 4 additions & 2 deletions requirements.txt

@@ -39,8 +39,10 @@ tensorboard-data-server==0.7.1
 threadpoolctl==3.2.0
 torch==2.0.1
 torch-geometric==2.3.1
-# torch-scatter==2.1.1
-# torch-sparse==0.6.17
+-f https://download.pytorch.org/whl/cpu/torch_stable.html
+-f https://data.pyg.org/whl/torch-2.0.1+cpu.html
+torch-scatter==2.1.1
+torch-sparse==0.6.17
 torchaudio==2.0.2
 torchvision==0.15.2
 tqdm==4.66.1
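
The two `-f` (`--find-links`) entries are pip options, which requirements files may embed; they point the resolver at indexes of wheels prebuilt against `torch==2.0.1` on CPU, so `torch-scatter` and `torch-sparse` now install in the same pass as `torch`, replacing the manual two-step install dropped from the README above. The equivalent standalone command, shown for illustration only:

```
pip install torch-scatter==2.1.1 torch-sparse==0.6.17 \
    -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
```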
Binary file removed: screenshot.png
68 changes: 12 additions & 56 deletions spark_sched_sim/components/job.py

@@ -5,16 +5,10 @@
 
 
 class Job:
-    '''An object representing a job in the system, containing a set of stages with dependencies stored in a dag.'''
+    """An object representing a job in the system, containing a set of stages with dependencies stored in a dag."""
 
     def __init__(
-        self,
-        id_: int,
-        stages: list[Stage],
-        dag: nx.DiGraph,
-        t_arrival: float,
-        query_size: int,
-        query_num: int
+        self, id_: int, stages: list[Stage], dag: nx.DiGraph, t_arrival: float
     ):
         # unique identifier of this job
         self.id_ = id_
@@ -27,7 +21,7 @@ def __init__(
         # all incomplete stages
         # TODO: use ordered set
         self.active_stages = stages.copy()
-
+
         # incomplete stages whose parents have completed
         self.frontier_stages = set()
 
@@ -37,10 +31,6 @@ def __init__(
 
         # time that this job arrived into the system
         self.t_arrival = t_arrival
 
-        self.query_size = query_size
-
-        self.query_num = query_num
-
         # time that this job completed, i.e. when the last
         # stage completed
         self.t_completed = np.inf
@@ -53,44 +43,28 @@ def __init__(
 
         self.init_frontier()
 
-
-    def __str__(self):
-        return f'TPCH_{self.query_num}_{self.query_size}'
-
-
-
     @property
     def pool_key(self):
         return (self.id_, None)
 
-
-
     @property
     def completed(self):
         return self.num_active_stages == 0
 
-
-
     @property
     def saturated(self):
         return self.saturated_stage_count == len(self.stages)
 
-
-
     @property
     def num_stages(self):
         return len(self.stages)
 
-
-
     @property
     def num_active_stages(self):
         return len(self.active_stages)
 
-
-
     def add_stage_completion(self, stage):
-        '''increments the count of completed stages'''
+        """increments the count of completed stages"""
         self.active_stages.remove(stage)
 
         self.frontier_stages.remove(stage)
@@ -99,41 +73,29 @@ def add_stage_completion(self, stage):
         self.frontier_stages |= new_stages
 
         return len(new_stages) > 0
-
-
-
     def init_frontier(self):
-        '''returns a set containing all the stages which are
+        """returns a set containing all the stages which are
         source nodes in the dag, i.e. which have no dependencies
-        '''
+        """
         assert len(self.frontier_stages) == 0
         self.frontier_stages |= self.source_stages()
 
-
-
     def source_stages(self):
         return set(
-            self.stages[node]
-            for node, in_deg in self.dag.in_degree()
-            if in_deg == 0
+            self.stages[node] for node, in_deg in self.dag.in_degree() if in_deg == 0
        )
 
-
-
     def children_stages(self, stage):
         return (self.stages[stage_id] for stage_id in self.dag.successors(stage.id_))
 
-
-
     def parent_stages(self, stage):
         return (self.stages[stage_id] for stage_id in self.dag.predecessors(stage.id_))
 
-
-
     def find_new_frontier_stages(self, stage):
-        '''if ` stage` is completed, returns all of its successors whose other dependencies are also
+        """if ` stage` is completed, returns all of its successors whose other dependencies are also
         completed, if any exist.
-        '''
+        """
         if not stage.completed:
             return set()
 
@@ -144,29 +106,23 @@ def find_new_frontier_stages(self, stage):
             new_stage = self.stages[suc_stage_id]
             if not new_stage.completed and self.check_dependencies(suc_stage_id):
                 new_stages.add(new_stage)
-
-        return new_stages
-
-
+        return new_stages
 
     def check_dependencies(self, stage_id):
-        '''searches to see if all the dependencies of stage with id `stage_id` are satisfied.'''
+        """searches to see if all the dependencies of stage with id `stage_id` are satisfied."""
         for dep_id in self.dag.predecessors(stage_id):
             if not self.stages[dep_id].completed:
                 return False
 
         return True
 
-
-
     def add_local_executor(self, executor):
         assert executor.task is None
         self.local_executors.add(executor.id_)
         executor.job_id = self.id_
 
-
-
     def remove_local_executor(self, executor):
         self.local_executors.remove(executor.id_)
         executor.job_id = None
-        executor.task = None
+        executor.task = None
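
The frontier bookkeeping that `Job` maintains (source stages first, then successors whose parents have all completed) can be reproduced with a self-contained sketch using only `networkx`. This mirrors the logic of `source_stages` and `check_dependencies` above but is illustrative code, not part of the repository:

```python
import networkx as nx

# Stage dependencies: 0 and 1 are sources; 2 needs both 0 and 1; 3 needs only 0.
dag = nx.DiGraph([(0, 2), (1, 2), (0, 3)])

def frontier(dag, completed):
    # A stage is schedulable when it is incomplete and every parent is complete,
    # matching Job.source_stages (in-degree 0) plus Job.check_dependencies.
    return {
        n for n in dag.nodes
        if n not in completed and all(p in completed for p in dag.predecessors(n))
    }

print(frontier(dag, set()))  # {0, 1}: the source stages
print(frontier(dag, {0}))    # {1, 3}: stage 3 unlocked, stage 2 still waits on 1
```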