more adaptive scaling fixes (#97)
more adaptive scaling fixes from #63
Joe Hamman authored and guillaumeeb committed Oct 7, 2018
1 parent 3204da4 commit ee6e79e
Showing 6 changed files with 109 additions and 58 deletions.
2 changes: 1 addition & 1 deletion ci/pbs/Dockerfile
@@ -12,7 +12,7 @@ RUN git clone --branch v14.1.2 https://github.com/pbspro/pbspro.git /src/pbspro
COPY build.sh /
RUN bash /build.sh

# base image
# base image
FROM centos:7.4.1708
LABEL description="PBS Professional Open Source and conda"

85 changes: 61 additions & 24 deletions dask_jobqueue/core.py
@@ -50,16 +50,17 @@ def __init__(self):

def add_worker(self, scheduler, worker=None, name=None, **kwargs):
''' Run when a new worker enters the cluster'''
logger.debug("adding worker %s" % worker)
logger.debug("adding worker %s", worker)
w = scheduler.workers[worker]
job_id = _job_id_from_worker_name(w.name)
logger.debug("job id for new worker: %s" % job_id)
logger.debug("job id for new worker: %s", job_id)
self.all_workers[worker] = (w.name, job_id)

# if this is the first worker for this job, move job to running
if job_id not in self.running_jobs:
logger.debug("this is a new job or restarting worker")
logger.debug("%s is a new job or restarting worker", job_id)
if job_id in self.pending_jobs:
logger.debug("%s is a new job, adding to running_jobs", job_id)
self.running_jobs[job_id] = self.pending_jobs.pop(job_id)
elif job_id in self.finished_jobs:
logger.warning('Worker %s restart in Job %s. '
@@ -74,16 +75,16 @@ def add_worker(self, scheduler, worker=None, name=None, **kwargs):

def remove_worker(self, scheduler=None, worker=None, **kwargs):
''' Run when a worker leaves the cluster'''
logger.debug("removing worker %s" % worker)
logger.debug("removing worker %s", worker)
name, job_id = self.all_workers[worker]
logger.debug("removing worker name (%s) and job_id (%s)" % (name, job_id))
logger.debug("removing worker name (%s) and job_id (%s)", name, job_id)

# remove worker from this job
del self.running_jobs[job_id][name]
self.running_jobs[job_id].pop(name, None)

# once there are no more workers, move this job to finished_jobs
if not self.running_jobs[job_id]:
logger.debug("that was the last worker for job %s" % job_id)
logger.debug("that was the last worker for job %s", job_id)
self.finished_jobs[job_id] = self.running_jobs.pop(job_id)


@@ -222,8 +223,7 @@ def __init__(self,
self.local_cluster = LocalCluster(n_workers=0, diagnostics_port=diagnostics_ip_and_port,
**kwargs)

# Keep information on process, threads and memory, for use in
# subclasses
# Keep information on process, cores, and memory, for use in subclasses
self.worker_memory = parse_bytes(memory) if memory is not None else None
self.worker_processes = processes
self.worker_cores = cores
@@ -257,8 +257,10 @@ def __init__(self,

self._command_template = ' '.join(map(str, command_args))

self._target_scale = 0

def __repr__(self):
running_workers = sum(len(value) for value in self.running_jobs.values())
running_workers = self._count_active_workers()
running_cores = running_workers * self.worker_threads
total_jobs = len(self.pending_jobs) + len(self.running_jobs)
total_workers = total_jobs * self.worker_processes
@@ -301,7 +303,7 @@ def job_file(self):
""" Write job submission script to temporary file """
with tmpfile(extension='sh') as fn:
with open(fn, 'w') as f:
logger.debug("writing job script: \n%s" % self.job_script())
logger.debug("writing job script: \n%s", self.job_script())
f.write(self.job_script())
yield fn

@@ -310,13 +312,15 @@ def _submit_job(self, script_filename):

def start_workers(self, n=1):
""" Start workers and point them to our local scheduler """
logger.debug('starting %s workers' % n)
logger.debug('starting %s workers', n)
num_jobs = int(math.ceil(n / self.worker_processes))
for _ in range(num_jobs):
with self.job_file() as fn:
out = self._submit_job(fn)
job = self._job_id_from_submit_output(out.decode())
logger.debug("started job: %s" % job)
if not job:
raise ValueError('Unable to parse jobid from output of %s' % out)
logger.debug("started job: %s", job)
self.pending_jobs[job] = {}

@property
@@ -354,7 +358,7 @@ def _calls(self, cmds, **kwargs):
for proc in procs:
out, err = proc.communicate()
if err:
logger.error(err.decode())
raise RuntimeError(err.decode())
result.append(out)
return result

@@ -364,7 +368,7 @@ def _call(self, cmd, **kwargs):

def stop_workers(self, workers):
""" Stop a list of workers"""
logger.debug("Stopping workers: %s" % workers)
logger.debug("Stopping workers: %s", workers)
if not workers:
return
jobs = self._del_pending_jobs() # stop pending jobs too
@@ -373,33 +377,66 @@ def stop_workers(self, workers):
jobs.append(_job_id_from_worker_name(w['name']))
else:
jobs.append(_job_id_from_worker_name(w.name))
self.stop_jobs(set(jobs))
self.stop_jobs(jobs)

def stop_jobs(self, jobs):
""" Stop a list of jobs"""
logger.debug("Stopping jobs: %s" % jobs)
logger.debug("Stopping jobs: %s", jobs)
if jobs:
jobs = list(jobs)
self._call([self.cancel_command] + list(set(jobs)))

# if any of these jobs were pending, we should remove those now
for job_id in jobs:
if job_id in self.pending_jobs:
del self.pending_jobs[job_id]

def scale_up(self, n, **kwargs):
""" Brings total worker count up to ``n`` """
logger.debug("Scaling up to %d workers." % n)
active_and_pending = sum([len(j) for j in self.running_jobs.values()])
active_and_pending += self.worker_processes * len(self.pending_jobs)
logger.debug("Found %d active/pending workers." % active_and_pending)
self.start_workers(n - active_and_pending)
active_and_pending = self._count_active_and_pending_workers()
if n >= active_and_pending:
logger.debug("Scaling up to %d workers.", n)
self.start_workers(n - self._count_active_and_pending_workers())
else:
n_to_close = active_and_pending - n
if n_to_close < self._count_pending_workers():
# We only need to kill some pending jobs, this is actually a
# scale down but could not be handled upstream
to_kill = int(n_to_close / self.worker_processes)
jobs = list(self.pending_jobs.keys())[to_kill:]
self.stop_jobs(jobs)
else:
# We should not end here, a new scale call should not begin
# until a scale_up or scale_down has ended
raise RuntimeError('JobQueueCluster.scale_up was called with'
' a number of workers lower than the '
'currently connected workers')

def _count_active_and_pending_workers(self):
active_and_pending = (self._count_active_workers() +
self._count_pending_workers())
logger.debug("Found %d active/pending workers.", active_and_pending)
assert len(self.scheduler.workers) <= active_and_pending
return active_and_pending

def _count_active_workers(self):
active_workers = sum([len(j) for j in self.running_jobs.values()])
assert len(self.scheduler.workers) == active_workers
return active_workers

def _count_pending_workers(self):
return self.worker_processes * len(self.pending_jobs)

def scale_down(self, workers):
''' Close the workers with the given addresses '''
logger.debug("Scaling down. Workers: %s" % workers)
logger.debug("Scaling down. Workers: %s", workers)
worker_states = []
for w in workers:
try:
# Get the actual WorkerState
worker_states.append(self.scheduler.workers[w])
except KeyError:
logger.debug('worker %s is already gone' % w)
logger.debug('worker %s is already gone', w)
self.stop_workers(worker_states)

def stop_all_jobs(self):
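The core of the change in dask_jobqueue/core.py is the bookkeeping that maps queue jobs to dask workers: a job moves from pending_jobs to running_jobs when its first worker connects and to finished_jobs when its last worker leaves, while scale_up compares the requested worker count with the active-plus-pending total and either submits new jobs or cancels pending ones. The following is a standalone, hedged sketch of that state machine; the job ids, worker-name format, and submit/cancel stand-ins are illustrative, and the counting mirrors the intent of the code above rather than reproducing it line for line.

import math
from itertools import count

worker_processes = 2            # workers started per submitted job (sketch value)
pending_jobs, running_jobs, finished_jobs = {}, {}, {}
job_counter = count()           # stand-in for queue-assigned job ids

def job_id_from_worker_name(name):
    # assumes worker names embed the job id, e.g. 'dask-worker--job-0--0'
    return name.split('--')[1]

def add_worker(name):
    job_id = job_id_from_worker_name(name)
    if job_id in pending_jobs:                  # first worker of this job
        running_jobs[job_id] = pending_jobs.pop(job_id)
    running_jobs.setdefault(job_id, {})[name] = None

def remove_worker(name):
    job_id = job_id_from_worker_name(name)
    running_jobs[job_id].pop(name, None)
    if not running_jobs[job_id]:                # last worker of this job
        finished_jobs[job_id] = running_jobs.pop(job_id)

def count_active_and_pending():
    active = sum(len(workers) for workers in running_jobs.values())
    return active + worker_processes * len(pending_jobs)

def scale_up(n):
    current = count_active_and_pending()
    if n >= current:
        for _ in range(int(math.ceil((n - current) / worker_processes))):
            pending_jobs['job-%d' % next(job_counter)] = {}   # stand-in for qsub/sbatch
    elif current - n < worker_processes * len(pending_jobs):
        # a disguised scale down: cancel just enough pending jobs
        n_cancel = int(math.ceil((current - n) / worker_processes))
        for job_id in list(pending_jobs)[:n_cancel]:
            pending_jobs.pop(job_id)                          # stand-in for qdel/scancel

scale_up(4)                          # submits job-0 and job-1
add_worker('dask-worker--job-0--0')  # job-0 moves to running_jobs
scale_up(2)                          # cancels the still-pending job-1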
2 changes: 1 addition & 1 deletion dask_jobqueue/tests/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import, division, print_function

QUEUE_WAIT = 60 # seconds
QUEUE_WAIT = 15 # seconds
41 changes: 26 additions & 15 deletions dask_jobqueue/tests/test_pbs.py
@@ -91,8 +91,14 @@ def test_basic(loop):
with PBSCluster(walltime='00:02:00', processes=1, cores=2, memory='2GB', local_directory='/tmp',
job_extra=['-V'], loop=loop) as cluster:
with Client(cluster) as client:
cluster.start_workers(2)
assert cluster.pending_jobs or cluster.running_jobs

cluster.scale(2)

start = time()
while not(cluster.pending_jobs or cluster.running_jobs):
sleep(0.100)
assert time() < start + QUEUE_WAIT

future = client.submit(lambda x: x + 1, 10)
assert future.result(QUEUE_WAIT) == 11
assert cluster.running_jobs
@@ -102,29 +108,38 @@ def test_basic(loop):
assert w['memory_limit'] == 2e9
assert w['ncores'] == 2

cluster.stop_workers(workers)
cluster.scale(0)

start = time()
while client.scheduler_info()['workers']:
while cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT

assert not cluster.running_jobs


@pytest.mark.env("pbs") # noqa: F811
@pytest.mark.skip(reason="Current scale method not capable of doing this")
def test_basic_scale_edge_cases(loop):
with PBSCluster(walltime='00:02:00', processes=1, cores=2, memory='2GB', local_directory='/tmp',
job_extra=['-V'], loop=loop) as cluster:

cluster.scale(2)
cluster.scale(0)

# Wait to see what happens
sleep(0.2)

assert not(cluster.pending_jobs or cluster.running_jobs)


@pytest.mark.env("pbs") # noqa: F811
def test_adaptive(loop):
with PBSCluster(walltime='00:02:00', processes=1, cores=2, memory='2GB', local_directory='/tmp',
job_extra=['-V'], loop=loop) as cluster:
cluster.adapt()
with Client(cluster) as client:
future = client.submit(lambda x: x + 1, 10)

start = time()
while not (cluster.pending_jobs or cluster.running_jobs):
sleep(0.100)
assert time() < start + QUEUE_WAIT

assert future.result(QUEUE_WAIT) == 11

start = time()
@@ -135,15 +150,11 @@ def test_adaptive(loop):

del future

start = time()
while len(client.scheduler_info()['workers']) > 0:
sleep(0.100)
assert time() < start + QUEUE_WAIT

start = time()
while cluster.pending_jobs or cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT

assert cluster.finished_jobs


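The rewritten tests above all use the same polling idiom: change the desired cluster size, then loop on the job dictionaries until the queue catches up or QUEUE_WAIT expires. A small hedged helper capturing that pattern (wait_for and its arguments are illustrative and not part of the test suite):

from time import sleep, time

QUEUE_WAIT = 15  # seconds, matching dask_jobqueue/tests/__init__.py above

def wait_for(condition, timeout=QUEUE_WAIT, interval=0.1):
    # Poll `condition` until it returns a truthy value or `timeout` seconds pass.
    start = time()
    while not condition():
        sleep(interval)
        assert time() < start + timeout

# e.g. after cluster.scale(2):
#     wait_for(lambda: cluster.pending_jobs or cluster.running_jobs)
# and after cluster.scale(0):
#     wait_for(lambda: not cluster.running_jobs)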
15 changes: 9 additions & 6 deletions dask_jobqueue/tests/test_sge.py
@@ -16,8 +16,13 @@ def test_basic(loop): # noqa: F811
with SGECluster(walltime='00:02:00', cores=8, processes=4, memory='2GB', loop=loop) as cluster:
print(cluster.job_script())
with Client(cluster, loop=loop) as client:
cluster.start_workers(2)
assert cluster.pending_jobs or cluster.running_jobs

cluster.scale(2)

start = time()
while not(cluster.pending_jobs or cluster.running_jobs):
sleep(0.100)
assert time() < start + QUEUE_WAIT

future = client.submit(lambda x: x + 1, 10)
assert future.result(QUEUE_WAIT) == 11
@@ -28,11 +33,9 @@ def test_basic(loop): # noqa: F811
assert w['memory_limit'] == 2e9 / 4
assert w['ncores'] == 2

cluster.stop_workers(workers)
cluster.scale(0)

start = time()
while client.scheduler_info()['workers']:
while cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT

assert not cluster.running_jobs
22 changes: 11 additions & 11 deletions dask_jobqueue/tests/test_slurm.py
@@ -91,8 +91,14 @@ def test_basic(loop):
with SLURMCluster(walltime='00:02:00', cores=2, processes=1, memory='2GB',
job_extra=['-D /'], loop=loop) as cluster:
with Client(cluster) as client:
cluster.start_workers(2)
assert cluster.pending_jobs or cluster.running_jobs

cluster.scale(2)

start = time()
while not(cluster.pending_jobs or cluster.running_jobs):
sleep(0.100)
assert time() < start + QUEUE_WAIT

future = client.submit(lambda x: x + 1, 10)
assert future.result(QUEUE_WAIT) == 11
assert cluster.running_jobs
@@ -102,15 +108,13 @@ def test_basic(loop):
assert w['memory_limit'] == 2e9
assert w['ncores'] == 2

cluster.stop_workers(workers)
cluster.scale(0)

start = time()
while client.scheduler_info()['workers']:
while cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT

assert not cluster.running_jobs


@pytest.mark.env("slurm") # noqa: F811
def test_adaptive(loop):
@@ -136,12 +140,8 @@ def test_adaptive(loop):
del future

start = time()
while len(client.scheduler_info()['workers']) > 0:
while cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT

start = time()
while cluster.pending_jobs or cluster.running_jobs:
sleep(0.100)
assert time() < start + QUEUE_WAIT
assert cluster.finished_jobs
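Outside the test harness, the adaptive path exercised above follows the usual dask pattern of pairing a job-queue cluster with cluster.adapt(); a hedged end-to-end sketch (the walltime and resource values are illustrative, not taken from this diff):

from dask.distributed import Client
from dask_jobqueue import PBSCluster

# Illustrative resource spec; choose values appropriate for your queue.
cluster = PBSCluster(walltime='00:30:00', cores=2, processes=1, memory='2GB')
cluster.adapt()      # jobs are submitted and cancelled as the workload changes

client = Client(cluster)
future = client.submit(lambda x: x + 1, 10)
print(future.result())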
