Skip to content

Commit 1b8704b

Browse files
committed
feat(backend): Support namespaced TrainingRuntime in the SDK
1 parent e878505 commit 1b8704b

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,33 @@ def __init__(
6060
def list_runtimes(self) -> list[types.Runtime]:
6161
result = []
6262
try:
63-
thread = self.custom_api.list_cluster_custom_object(
63+
cluster_thread = self.custom_api.list_cluster_custom_object(
6464
constants.GROUP,
6565
constants.VERSION,
6666
constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
6767
async_req=True,
6868
)
6969

70-
runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict(
71-
thread.get(constants.DEFAULT_TIMEOUT)
70+
namespace_thread = self.custom_api.list_namespaced_custom_object(
71+
constants.GROUP,
72+
constants.VERSION,
73+
self.namespace,
74+
constants.TRAINING_RUNTIME_PLURAL,
75+
async_req=True,
76+
)
77+
78+
cluster_runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict(
79+
cluster_thread.get(constants.DEFAULT_TIMEOUT)
80+
)
81+
82+
namespace_runtime_list = models.TrainerV1alpha1TrainingRuntimeList.from_dict(
83+
namespace_thread.get(constants.DEFAULT_TIMEOUT)
7284
)
7385

74-
if not runtime_list:
86+
if not (cluster_runtime_list and namespace_runtime_list):
7587
return result
7688

77-
for runtime in runtime_list.items:
89+
for runtime in namespace_runtime_list.items + cluster_runtime_list.items:
7890
if not (
7991
runtime.metadata
8092
and runtime.metadata.labels
@@ -89,33 +101,55 @@ def list_runtimes(self) -> list[types.Runtime]:
89101

90102
except multiprocessing.TimeoutError as e:
91103
raise TimeoutError(
92-
f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
104+
"Timeout to list "
105+
f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s "
93106
f"in namespace: {self.namespace}"
94107
) from e
95108
except Exception as e:
96109
raise RuntimeError(
97-
f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
110+
"Failed to list "
111+
f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s "
98112
f"in namespace: {self.namespace}"
99113
) from e
100114

101115
return result
102116

103117
def get_runtime(self, name: str) -> types.Runtime:
104-
"""Get the the Runtime object"""
118+
"""Get the the Runtime object prefer namespaced, fall-back to cluster-scoped"""
105119

106120
try:
107-
thread = self.custom_api.get_cluster_custom_object(
121+
cluster_thread = self.custom_api.get_cluster_custom_object(
108122
constants.GROUP,
109123
constants.VERSION,
110124
constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
111125
name,
112126
async_req=True,
113127
)
114128

115-
runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(
116-
thread.get(constants.DEFAULT_TIMEOUT) # type: ignore
129+
namespace_thread = self.custom_api.get_namespaced_custom_object(
130+
constants.GROUP,
131+
constants.VERSION,
132+
self.namespace,
133+
constants.TRAINING_RUNTIME_PLURAL,
134+
name,
135+
async_req=True,
117136
)
118137

138+
# Try namespaced runtime first, fall back to cluster-scoped one
139+
try:
140+
runtime = models.TrainerV1alpha1TrainingRuntime.from_dict(
141+
namespace_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore
142+
)
143+
except Exception as e:
144+
logger.warning(
145+
f"Namespaced TrainingRuntime '{self.namespace}/{name}' not found "
146+
f"({type(e).__name__}: {e}); falling back to cluster-scoped runtime."
147+
)
148+
149+
runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(
150+
cluster_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore
151+
)
152+
119153
except multiprocessing.TimeoutError as e:
120154
raise TimeoutError(
121155
f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: "
@@ -433,8 +467,13 @@ def delete_job(self, name: str):
433467

434468
def __get_runtime_from_crd(
435469
self,
436-
runtime_crd: models.TrainerV1alpha1ClusterTrainingRuntime,
470+
runtime_crd: Union[
471+
models.TrainerV1alpha1ClusterTrainingRuntime, models.TrainerV1alpha1TrainingRuntime
472+
],
437473
) -> types.Runtime:
474+
crd_kind = getattr(runtime_crd, "kind", "UnknownKind")
475+
crd_name = getattr(runtime_crd.metadata, "name", "UnknownName")
476+
438477
if not (
439478
runtime_crd.metadata
440479
and runtime_crd.metadata.name
@@ -443,7 +482,11 @@ def __get_runtime_from_crd(
443482
and runtime_crd.spec.template.spec
444483
and runtime_crd.spec.template.spec.replicated_jobs
445484
):
446-
raise Exception(f"ClusterTrainingRuntime CRD is invalid: {runtime_crd}")
485+
raise Exception(
486+
f"{crd_kind} '{crd_name}' is invalid — missing one or more required fields: "
487+
f"metadata.name, spec.mlPolicy, spec.template.spec.replicatedJobs.\n"
488+
f"Full object: {runtime_crd}"
489+
)
447490

448491
if not (
449492
runtime_crd.metadata.labels

kubeflow/trainer/constants/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
# The plural for the ClusterTrainingRuntime.
3333
CLUSTER_TRAINING_RUNTIME_PLURAL = "clustertrainingruntimes"
3434

35+
# The Kind name for the TrainingRuntime.
36+
TRAINING_RUNTIME_KIND = "ClusterTrainingRuntime"
37+
38+
# The plural for the ClusterTrainingRuntime.
39+
TRAINING_RUNTIME_PLURAL = "trainingruntimes"
40+
3541
# The Kind name for the TrainJob.
3642
TRAINJOB_KIND = "TrainJob"
3743

0 commit comments

Comments
 (0)