 import shlex
 import subprocess
 import tempfile
+import warnings
 from dataclasses import dataclass
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from datetime import datetime
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
-from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler
+from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler, Stream
+from torchx.schedulers.local_scheduler import LogIterator
 from torchx.specs import (
     NONE,
     AppDef,
@@ -100,26 +103,41 @@ def from_role(
         if resource.gpu > 0:
             sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
 
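+        # write each replica's output to its own file so log_iter / torchx log can read it back later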
+        srun_opts = {
+            "output": f"slurm-{macros.app_id}-{name}.out",
+        }
+
         return cls(
             name=name,
             entrypoint=role.entrypoint,
             args=list(role.args),
             sbatch_opts=sbatch_opts,
-            srun_opts={},
+            srun_opts=srun_opts,
             env=dict(role.env),
         )
 
+    def _opts_to_strs(self, opts: Dict[str, str]) -> List[str]:
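+        """Format an options dict as ``--key=value`` flags; a ``None`` value yields a bare ``--key``."""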
+        out = []
+        for key, value in opts.items():
+            if value is not None:
+                out.append(f"--{key}={value}")
+            else:
+                out.append(f"--{key}")
+        return out
+
     def materialize(self) -> Tuple[List[str], List[str]]:
         """
         materialize returns the sbatch and srun groups for this role. They
         should be combined using `:` per slurm heterogeneous groups.
         """
         sbatch_args = [
             f"--job-name={self.name}",
-        ] + [f"--{key}={value}" for key, value in self.sbatch_opts.items()]
-        srun_args = [f"--{key}={value}" for key, value in self.srun_opts.items()] + [
-            f"--export={key}={value}" for key, value in self.env.items()
-        ]
+        ] + self._opts_to_strs(self.sbatch_opts)
+        srun_args = self._opts_to_strs(self.srun_opts)
+
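+        # --export=ALL,... keeps the inherited environment and appends the role's env vars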
+        if len(self.env) > 0:
+            kvs = [f"{key}={value}" for key, value in self.env.items()]
+            srun_args += ["--export=ALL," + ",".join(kvs)]
 
         srun_group = srun_args + [self.entrypoint] + self.args
         srun_group = [_apply_app_id_env(arg) for arg in srun_group]
@@ -160,6 +178,9 @@ def materialize(self) -> str:
 # exit on error
 set -e
 
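+# disable output buffering so the per-replica log files update as lines are produced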
+export PYTHONUNBUFFERED=1
+export SLURM_UNBUFFEREDIO=1
+
 srun {" ".join(srun_groups)}
 """
         sbatch_cmd = self.cmd + sbatch_groups
@@ -176,7 +197,11 @@ class SlurmScheduler(Scheduler):
     resource allocations and args and then sbatch is used to launch all of them
     together.
 
-    Logs are written to the default slurm log file.
+    Logs are available in combined form via ``torchx log`` and the programmatic
+    API, as well as in the job launch directory as
+    ``slurm-<jobid>-<role>-<replica_id>.out``. If TorchX is running in a
+    different directory than where the job was created, the logs won't be
+    found.
 
     Some of the config options passed to it are added as SBATCH arguments to each
     replica. See https://slurm.schedmd.com/sbatch.html#SECTION_OPTIONS for info
@@ -203,9 +228,7 @@ class SlurmScheduler(Scheduler):
         type: scheduler
         features:
             cancel: true
-            logs: |
-                Logs are accessible via the default slurm log file but not the
-                programmatic API.
+            logs: true
             distributed: true
             describe: |
                 Partial support. SlurmScheduler will return job and replica
@@ -262,7 +285,7 @@ def _submit_dryrun(
                     app_id=macros.app_id,
                     replica_id=str(replica_id),
                 )
-                name = f"{app.name}-{role.name}-{replica_id}"
+                name = f"{role.name}-{replica_id}"
                 replica_role = values.apply(role)
                 replicas[name] = SlurmReplicaRequest.from_role(name, replica_role, cfg)
         req = SlurmBatchRequest(
@@ -286,6 +309,8 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         if len(output) <= 1:
             return None
 
         reader = csv.DictReader(output, delimiter="|")
 
         roles = {}
@@ -308,19 +333,19 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
             ), f"failed to translate slurm state {state} to torchx state"
             app_state = state_enum
 
-            name_parts = row["JobName"].split("-")
-            if len(name_parts) < 3:
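+            # split on the last "-" so role names may themselves contain dashes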
+            role, _, replica_id = row["JobName"].rpartition("-")
+            if not replica_id or not role:
                 # name should always have at least 3 parts but sometimes sacct
                 # is slow to update
                 continue
-            role = name_parts[-2]
-            replica_id = int(name_parts[-1])
             if role not in roles:
                 roles[role] = Role(name=role, num_replicas=0, image="")
                 roles_statuses[role] = RoleStatus(role, [])
             roles[role].num_replicas += 1
             roles_statuses[role].replicas.append(
-                ReplicaStatus(id=replica_id, role=role, state=app_state, hostname=""),
+                ReplicaStatus(
+                    id=int(replica_id), role=role, state=app_state, hostname=""
+                ),
             )
 
         return DescribeAppResponse(
@@ -331,6 +356,34 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )
 
+    def log_iter(
+        self,
+        app_id: str,
+        role_name: str,
+        k: int = 0,
+        regex: Optional[str] = None,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+        should_tail: bool = False,
+        streams: Optional[Stream] = None,
+    ) -> Iterable[str]:
+        if since or until:
+            warnings.warn(
+                "since and/or until times specified for SlurmScheduler.log_iter."
+                " These will be ignored and all log lines will be returned"
+            )
+        if streams is not None and streams != Stream.COMBINED:
+            warnings.warn(
+                "streams specified for SlurmScheduler.log_iter."
+                " These will be ignored and all log lines will be returned"
+            )
+
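+        # read the per-replica file written via the srun --output option set in from_role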
+        log_file = f"slurm-{app_id}-{role_name}-{k}.out"
+
+        return LogIterator(
+            app_id, regex or ".*", log_file, self, should_tail=should_tail
+        )
+
 
 def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
     return SlurmScheduler(