Skip to content

Commit

Permalink
Merge pull request spotify#1 from markgaynor/slurm
Browse files: browse the repository at this point in the history
Now works with Python 3
  • Loading branch information
jcftang authored Dec 18, 2017
2 parents 2b19c91 + 0119727 commit 8d88f6d
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions luigi/contrib/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.
#

"""Slurm batch system Tasks.
"""
Slurm batch system Tasks.
Adapted by Jimmy Tang <jtang@voysis.com> from the sge.py by Jake Feala (@jfeala)
Expand Down Expand Up @@ -76,10 +77,8 @@ def output(self):
shared-tmp-dir = /home
ntasks = 2
"""


# This extension is modeled after the hadoop.py approach.
#
# Implementation notes
Expand Down Expand Up @@ -110,7 +109,6 @@ def output(self):

POLL_TIME = 5 # decided to hard-code rather than configure here


def _parse_job_state(job_out):
"""Parse "state" from 'scontrol show jobid=ID -o' output
Expand All @@ -122,17 +120,14 @@ def _parse_job_state(job_out):
job_map = {}
for job in job_line:
job_s = job.split("=")
job_map[job_s[0]] = job_s[1]
try:
job_map[job_s[0]] = job_s[1]
except:
print("No value found for " + job_s[0])

return job_map.get('JobState', 'u')


def _parse_job_id(job_id):
"""Parse job id from scontrol output string.
"""
return int(job_id)


def _build_submit_command(cmd, job_name, outfile, errfile, ntasks, mem, gres, partition, time, sbatchfile):
"""Submit shell command to Slurm, queue via `sbatch`"""
sbatch_template = """#!/bin/bash
Expand Down Expand Up @@ -160,10 +155,10 @@ def _build_submit_command(cmd, job_name, outfile, errfile, ntasks, mem, gres, pa
sbatch_template=sbatch_template, job_name=job_name, outfile=outfile, errfile=errfile,
ntasks=ntasks, mem=mem, sbatchfile=sbatchfile, gres=gres, partition=partition, time=time)


class SlurmJobTask(luigi.Task):

"""Base class for executing a job on Slurm
"""
Base class for executing a job on Slurm
Override ``work()`` (rather than ``run()``) with your job code.
Expand Down Expand Up @@ -241,7 +236,6 @@ def _fetch_task_failures(self):
return errors

def _init_local(self):

# Set up temp folder in shared directory (trim to max filename length)
base_tmp_dir = self.shared_tmp_dir
random_id = '%016x' % random.getrandbits(64)
Expand Down Expand Up @@ -287,15 +281,14 @@ def _dump(self, out_dir=''):
with self.no_unpicklable_properties():
self.job_file = os.path.join(out_dir, 'job-instance.pickle')
if self.__module__ == '__main__':
d = pickle.dumps(self)
d = pickle.dumps(self, 0).decode()
module_name = os.path.basename(sys.argv[0]).rsplit('.', 1)[0]
d = d.replace('(c__main__', "(c" + module_name)
open(self.job_file, "w").write(d)
else:
pickle.dump(self, open(self.job_file, "w"))

def _run_job(self):

# Build a sbatch argument that will run sge_runner.py on the directory we've specified
runner_path = sge_runner.__file__
if runner_path.endswith("pyc"):
Expand All @@ -319,8 +312,8 @@ def _run_job(self):
os.chdir(self.tmp_dir)
output = subprocess.check_output(submit_cmd, shell=True)
os.chdir(cwd)
self.job_id = output.strip()
logger.debug("Submitted job to slurm with job id: {}".format(output))
self.job_id = output.decode().strip()
logger.debug("Submitted job to slurm with job id: {}".format(self.job_id))

self._track_job()

Expand All @@ -331,18 +324,19 @@ def _run_job(self):
shutil.rmtree(self.tmp_dir)

def _track_job(self):
start = time.time()
while True:
# Sleep for a little bit
time.sleep(self.poll_time)

# See what the job's up to
# ASSUMPTION
job_stat_out = subprocess.check_output(['scontrol', '-o', 'show', "jobid={}".format(self.job_id)])
job_stat_out = subprocess.check_output(['scontrol', '-o', 'show', "jobid={}".format(self.job_id)]).decode()
job_status = _parse_job_state(job_stat_out)
if job_status == 'RUNNING':
logger.info('Job is running...')
logger.info('Job is running ({:0.1f} seconds elapsed)...'.format(float(time.time() - start)))
elif job_status == 'PENDING':
logger.info('Job is pending...')
logger.info('Job is pending ({:0.1f} seconds elapsed)...'.format(float(time.time() - start)))
elif 'FAILED' in job_status:
logger.error('Job has FAILED:\n' + '\n'.join(self._fetch_task_failures()))
break
Expand Down

0 comments on commit 8d88f6d

Please sign in to comment.