@@ -60,21 +60,33 @@ def __init__(
6060 def list_runtimes (self ) -> list [types .Runtime ]:
6161 result = []
6262 try :
63- thread = self .custom_api .list_cluster_custom_object (
63+ cluster_thread = self .custom_api .list_cluster_custom_object (
6464 constants .GROUP ,
6565 constants .VERSION ,
6666 constants .CLUSTER_TRAINING_RUNTIME_PLURAL ,
6767 async_req = True ,
6868 )
6969
70- runtime_list = models .TrainerV1alpha1ClusterTrainingRuntimeList .from_dict (
71- thread .get (constants .DEFAULT_TIMEOUT )
70+ namespace_thread = self .custom_api .list_namespaced_custom_object (
71+ constants .GROUP ,
72+ constants .VERSION ,
73+ self .namespace ,
74+ constants .TRAINING_RUNTIME_PLURAL ,
75+ async_req = True ,
76+ )
77+
78+ cluster_runtime_list = models .TrainerV1alpha1ClusterTrainingRuntimeList .from_dict (
79+ cluster_thread .get (constants .DEFAULT_TIMEOUT )
80+ )
81+
82+ namespace_runtime_list = models .TrainerV1alpha1TrainingRuntimeList .from_dict (
83+ namespace_thread .get (constants .DEFAULT_TIMEOUT )
7284 )
7385
74- if not runtime_list :
86+ if not ( cluster_runtime_list and namespace_runtime_list ) :
7587 return result
7688
77- for runtime in runtime_list .items :
89+ for runtime in namespace_runtime_list . items + cluster_runtime_list .items :
7890 if not (
7991 runtime .metadata
8092 and runtime .metadata .labels
@@ -89,33 +101,55 @@ def list_runtimes(self) -> list[types.Runtime]:
89101
90102 except multiprocessing .TimeoutError as e :
91103 raise TimeoutError (
92- f"Timeout to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s "
104+ "Timeout to list "
105+ f"{ constants .CLUSTER_TRAINING_RUNTIME_KIND } s/{ constants .TRAINING_RUNTIME_KIND } s "
93106 f"in namespace: { self .namespace } "
94107 ) from e
95108 except Exception as e :
96109 raise RuntimeError (
97- f"Failed to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s "
110+ "Failed to list "
111+ f"{ constants .CLUSTER_TRAINING_RUNTIME_KIND } s/{ constants .TRAINING_RUNTIME_KIND } s "
98112 f"in namespace: { self .namespace } "
99113 ) from e
100114
101115 return result
102116
103117 def get_runtime (self , name : str ) -> types .Runtime :
104- """Get the the Runtime object"""
118+ """Get the Runtime object, preferring the namespaced one and falling back to the cluster-scoped one."""
105119
106120 try :
107- thread = self .custom_api .get_cluster_custom_object (
121+ cluster_thread = self .custom_api .get_cluster_custom_object (
108122 constants .GROUP ,
109123 constants .VERSION ,
110124 constants .CLUSTER_TRAINING_RUNTIME_PLURAL ,
111125 name ,
112126 async_req = True ,
113127 )
114128
115- runtime = models .TrainerV1alpha1ClusterTrainingRuntime .from_dict (
116- thread .get (constants .DEFAULT_TIMEOUT ) # type: ignore
129+ namespace_thread = self .custom_api .get_namespaced_custom_object (
130+ constants .GROUP ,
131+ constants .VERSION ,
132+ self .namespace ,
133+ constants .TRAINING_RUNTIME_PLURAL ,
134+ name ,
135+ async_req = True ,
117136 )
118137
138+ # Try namespaced runtime first, fall back to cluster-scoped one
139+ try :
140+ runtime = models .TrainerV1alpha1TrainingRuntime .from_dict (
141+ namespace_thread .get (constants .DEFAULT_TIMEOUT ) # type: ignore
142+ )
143+ except Exception as e :
144+ logger .warning (
145+ f"Namespaced TrainingRuntime '{ self .namespace } /{ name } ' not found "
146+ f"({ type (e ).__name__ } : { e } ); falling back to cluster-scoped runtime."
147+ )
148+
149+ runtime = models .TrainerV1alpha1ClusterTrainingRuntime .from_dict (
150+ cluster_thread .get (constants .DEFAULT_TIMEOUT ) # type: ignore
151+ )
152+
119153 except multiprocessing .TimeoutError as e :
120154 raise TimeoutError (
121155 f"Timeout to get { constants .CLUSTER_TRAINING_RUNTIME_PLURAL } : "
@@ -433,8 +467,13 @@ def delete_job(self, name: str):
433467
434468 def __get_runtime_from_crd (
435469 self ,
436- runtime_crd : models .TrainerV1alpha1ClusterTrainingRuntime ,
470+ runtime_crd : Union [
471+ models .TrainerV1alpha1ClusterTrainingRuntime , models .TrainerV1alpha1TrainingRuntime
472+ ],
437473 ) -> types .Runtime :
474+ crd_kind = getattr (runtime_crd , "kind" , "UnknownKind" )
475+ crd_name = getattr (runtime_crd .metadata , "name" , "UnknownName" )
476+
438477 if not (
439478 runtime_crd .metadata
440479 and runtime_crd .metadata .name
@@ -443,7 +482,11 @@ def __get_runtime_from_crd(
443482 and runtime_crd .spec .template .spec
444483 and runtime_crd .spec .template .spec .replicated_jobs
445484 ):
446- raise Exception (f"ClusterTrainingRuntime CRD is invalid: { runtime_crd } " )
485+ raise Exception (
486+ f"{ crd_kind } '{ crd_name } ' is invalid — missing one or more required fields: "
487+ f"metadata.name, spec.mlPolicy, spec.template.spec.replicatedJobs.\n "
488+ f"Full object: { runtime_crd } "
489+ )
447490
448491 if not (
449492 runtime_crd .metadata .labels
0 commit comments