@@ -2,6 +2,7 @@
 import asyncio
 import sys
 import json
+import os
 import re
 from tempfile import gettempdir
 from pathlib import Path
@@ -186,6 +187,173 @@ def close(self):
         self.pool.shutdown()
 
 
+class OarWorker(DistributedWorker):
+    """A worker to execute tasks on OAR systems."""
+
+    _cmd = "oarsub"
+
+    def __init__(self, loop=None, max_jobs=None, poll_delay=1, oarsub_args=None):
+        """
+        Initialize OAR Worker.
+
+        Parameters
+        ----------
+        poll_delay : seconds
+            Delay between polls to oar
+        oarsub_args : str
+            Additional oarsub arguments
+        max_jobs : int
+            Maximum number of submitted jobs
+
+        """
+        super().__init__(loop=loop, max_jobs=max_jobs)
+        if not poll_delay or poll_delay < 0:
+            poll_delay = 0
+        self.poll_delay = poll_delay
+        self.oarsub_args = oarsub_args or ""
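+        # job id -> stderr file, consulted to report the reason when a job fails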
+        self.error = {}
+
+    def run_el(self, runnable, rerun=False):
+        """Worker submission API."""
+        script_dir, batch_script = self._prepare_runscripts(runnable, rerun=rerun)
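+        # compute nodes must be able to read the generated scripts, so warn
+        # when they end up in a node-local temporary directory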
+        if (script_dir / script_dir.parts[1]) == gettempdir():
+            logger.warning("Temporary directories may not be shared across computers")
+        if isinstance(runnable, TaskBase):
+            cache_dir = runnable.cache_dir
+            name = runnable.name
+            uid = runnable.uid
+        else:  # runnable is a tuple (ind, pkl file, task)
+            cache_dir = runnable[-1].cache_dir
+            name = runnable[-1].name
+            uid = f"{runnable[-1].uid}_{runnable[0]}"
+
+        return self._submit_job(batch_script, name=name, uid=uid, cache_dir=cache_dir)
+
+    def _prepare_runscripts(self, task, interpreter="/bin/sh", rerun=False):
+        if isinstance(task, TaskBase):
+            cache_dir = task.cache_dir
+            ind = None
+            uid = task.uid
+        else:
+            ind = task[0]
+            cache_dir = task[-1].cache_dir
+            uid = f"{task[-1].uid}_{ind}"
+
+        script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
+        script_dir.mkdir(parents=True, exist_ok=True)
+        if ind is None:
+            if not (script_dir / "_task.pklz").exists():
+                save(script_dir, task=task)
+        else:
+            copyfile(task[1], script_dir / "_task.pklz")
+
+        task_pkl = script_dir / "_task.pklz"
+        if not task_pkl.exists() or not task_pkl.stat().st_size:
+            raise Exception("Missing or empty task!")
+
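+        # the batch script simply re-loads the pickled task and runs it with
+        # the submitting Python interpreter via load_and_run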
+        batchscript = script_dir / f"batchscript_{uid}.sh"
+        python_string = (
+            f"""'from pydra.engine.helpers import load_and_run; """
+            f"""load_and_run(task_pkl="{task_pkl}", ind={ind}, rerun={rerun}) '"""
+        )
+        bcmd = "\n".join(
+            (
+                f"#!{interpreter}",
+                f"{sys.executable} -c " + python_string,
+            )
+        )
+        with batchscript.open("wt") as fp:
+            fp.writelines(bcmd)
+        os.chmod(batchscript, 0o544)
+        return script_dir, batchscript
+
+    async def _submit_job(self, batchscript, name, uid, cache_dir):
+        """Coroutine that submits task runscript and polls job until completion or error."""
+        script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
+        sargs = self.oarsub_args.split()
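+        # honour -n/--name, -O/--stdout and -E/--stderr if already present in
+        # oarsub_args; otherwise add per-task defaults under script_dir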
+        jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
+        if not jobname:
+            jobname = ".".join((name, uid))
+            sargs.append(f"--name={jobname}")
+        output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
+        if not output:
+            output_file = str(script_dir / "oar-%jobid%.out")
+            sargs.append(f"--stdout={output_file}")
+        error = re.search(r"(?<=-E )\S+|(?<=--stderr=)\S+", self.oarsub_args)
+        if not error:
+            error_file = str(script_dir / "oar-%jobid%.err")
+            sargs.append(f"--stderr={error_file}")
+        else:
+            error_file = None
+        sargs.append(str(batchscript))
+        # TO CONSIDER: add random sleep to avoid overloading calls
+        logger.debug(f"Submitting job {' '.join(sargs)}")
+        rc, stdout, stderr = await read_and_display_async(
+            self._cmd, *sargs, hide_display=True
+        )
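+        # a successful submission prints "OAR_JOB_ID=<id>" on stdout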
+        jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
+        if rc:
+            raise RuntimeError(f"Error returned from oarsub: {stderr}")
+        elif not jobid:
+            raise RuntimeError("Could not extract job ID")
+        jobid = jobid.group(1)
+        if error_file:
+            error_file = error_file.replace("%jobid%", jobid)
+        self.error[jobid] = error_file
+        # intermittent polling
+        while True:
+            # 4 possibilities
+            # False: job is still pending/working
+            # Terminated: job is complete
+            # Error + idempotent: job has been stopped and resubmitted with another jobid
+            # Error: Job failure
+            done = await self._poll_job(jobid)
+            if not done:
+                await asyncio.sleep(self.poll_delay)
+            elif done == "Terminated":
+                return True
+            elif done == "Error" and "idempotent" in self.oarsub_args:
+                logger.debug(
+                    f"Job {jobid} has been stopped. Looking for its resubmission..."
+                )
+                # loading info about task with a specific uid
+                info_file = cache_dir / f"{uid}_info.json"
+                if info_file.exists():
+                    checksum = json.loads(info_file.read_text())["checksum"]
+                    if (cache_dir / f"{checksum}.lock").exists():
+                        # from Python 3.8 we could use missing_ok=True
+                        (cache_dir / f"{checksum}.lock").unlink()
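+                # ask oarstat for the job OAR created when it resubmitted this one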
+                cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
+                _, stdout, _ = await read_and_display_async(*cmd_re, hide_display=True)
+                if not stdout:
+                    raise RuntimeError(
+                        f"Job information about resubmission of job {jobid} not found"
+                    )
+                jobid = next(iter(json.loads(stdout).keys()), None)
+            else:
+                error_file = self.error[jobid]
+                error_line = Path(error_file).read_text().split("\n")[-2]
+                if "Exception" in error_line:
+                    error_message = error_line.replace("Exception: ", "")
+                elif "Error" in error_line:
+                    error_message = error_line.replace("Error: ", "")
+                else:
+                    error_message = "Job failed (unknown reason - TODO)"
+                raise Exception(error_message)
+
+    async def _poll_job(self, jobid):
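+        # -J asks oarstat for JSON output; -s limits it to the job state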
+        cmd = ("oarstat", "-J", "-s", "-j", jobid)
+        logger.debug(f"Polling job {jobid}")
+        _, stdout, _ = await read_and_display_async(*cmd, hide_display=True)
+        if not stdout:
+            raise RuntimeError("Job information not found")
+        status = json.loads(stdout)[jobid]
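+        # these states mean the job is still queued or executing; any other state is returned as-is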
+        if status in ["Waiting", "Launching", "Running", "Finishing"]:
+            return False
+        return status
+
+
 class SlurmWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""
 
@@ -1042,6 +1210,7 @@ def close(self):
     "slurm": SlurmWorker,
     "dask": DaskWorker,
     "sge": SGEWorker,
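+    # selected by passing plugin="oar" to pydra's Submitter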
+    "oar": OarWorker,
     **{
         "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype)
         for subtype in ["local", "slurm"]