Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Kubeflow Trainer client supports local development without needing a Kubernetes
### Available Backends

- **KubernetesBackend** (default) - Production training on Kubernetes
- **ContainerBackend** - Local development with Docker/Podman isolation
- **ContainerBackend** - Local development with Docker/Podman isolation
- **LocalProcessBackend** - Quick prototyping with Python subprocesses

**Quick Start:**
Expand Down
69 changes: 56 additions & 13 deletions kubeflow/trainer/backends/kubernetes/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,33 @@ def __init__(self, cfg: KubernetesBackendConfig):
def list_runtimes(self) -> list[types.Runtime]:
result = []
try:
thread = self.custom_api.list_cluster_custom_object(
cluster_thread = self.custom_api.list_cluster_custom_object(
constants.GROUP,
constants.VERSION,
constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
async_req=True,
)

runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict(
thread.get(common_constants.DEFAULT_TIMEOUT)
namespace_thread = self.custom_api.list_namespaced_custom_object(
constants.GROUP,
constants.VERSION,
self.namespace,
constants.TRAINING_RUNTIME_PLURAL,
async_req=True,
)

cluster_runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict(
cluster_thread.get(constants.DEFAULT_TIMEOUT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cluster_thread.get(constants.DEFAULT_TIMEOUT)
cluster_thread.get(common_constants.DEFAULT_TIMEOUT)

)

namespace_runtime_list = models.TrainerV1alpha1TrainingRuntimeList.from_dict(
namespace_thread.get(constants.DEFAULT_TIMEOUT)
)

if not runtime_list:
if not (cluster_runtime_list or namespace_runtime_list):
return result

for runtime in runtime_list.items:
for runtime in namespace_runtime_list.items + cluster_runtime_list.items:
if not (
runtime.metadata
and runtime.metadata.labels
Expand All @@ -88,33 +100,55 @@ def list_runtimes(self) -> list[types.Runtime]:

except multiprocessing.TimeoutError as e:
raise TimeoutError(
f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
"Timeout to list "
f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s "
f"in namespace: {self.namespace}"
) from e
except Exception as e:
raise RuntimeError(
f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
"Failed to list "
f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s "
f"in namespace: {self.namespace}"
) from e

return result

def get_runtime(self, name: str) -> types.Runtime:
"""Get the the Runtime object"""
"""Get the the Runtime object prefer namespaced, fall-back to cluster-scoped"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Get the the Runtime object prefer namespaced, fall-back to cluster-scoped"""
"""Get the Runtime object prefer namespaced, fall-back to cluster-scoped"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same change goes for each occurrence here


try:
thread = self.custom_api.get_cluster_custom_object(
cluster_thread = self.custom_api.get_cluster_custom_object(
constants.GROUP,
constants.VERSION,
constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
name,
async_req=True,
)

runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(
thread.get(common_constants.DEFAULT_TIMEOUT) # type: ignore
namespace_thread = self.custom_api.get_namespaced_custom_object(
constants.GROUP,
constants.VERSION,
self.namespace,
constants.TRAINING_RUNTIME_PLURAL,
name,
async_req=True,
)

# Try namespaced runtime first, fall back to cluster-scoped one
try:
runtime = models.TrainerV1alpha1TrainingRuntime.from_dict(
namespace_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore
)
except Exception as e:
logger.warning(
f"Namespaced TrainingRuntime '{self.namespace}/{name}' not found "
f"({type(e).__name__}: {e}); falling back to cluster-scoped runtime."
)

runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(
cluster_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore
)

except multiprocessing.TimeoutError as e:
raise TimeoutError(
f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: "
Expand Down Expand Up @@ -396,8 +430,13 @@ def delete_job(self, name: str):

def __get_runtime_from_cr(
self,
runtime_cr: models.TrainerV1alpha1ClusterTrainingRuntime,
runtime_cr: Union[
models.TrainerV1alpha1ClusterTrainingRuntime, models.TrainerV1alpha1TrainingRuntime
],
) -> types.Runtime:
crd_kind = getattr(runtime_cr, "kind", "UnknownKind")
crd_name = getattr(runtime_cr.metadata, "name", "UnknownName")

if not (
runtime_cr.metadata
and runtime_cr.metadata.name
Expand All @@ -406,7 +445,11 @@ def __get_runtime_from_cr(
and runtime_cr.spec.template.spec
and runtime_cr.spec.template.spec.replicated_jobs
):
raise Exception(f"ClusterTrainingRuntime CR is invalid: {runtime_cr}")
raise Exception(
f"{crd_kind} '{crd_name}' is invalid — missing one or more required fields: "
f"metadata.name, spec.mlPolicy, spec.template.spec.replicatedJobs.\n"
f"Full object: {runtime_cr}"
)

if not (
runtime_cr.metadata.labels
Expand Down
41 changes: 41 additions & 0 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,16 @@ def list_namespaced_custom_object_response(*args, **kwargs):
models.TrainerV1alpha1TrainJobList(items=items),
models.TrainerV1alpha1TrainJobList,
)
elif args[3] == constants.TRAINING_RUNTIME_PLURAL:
# TODO: add test case for namespace scoped runtimes
# items = [
# create_training_runtime(name="runtime-1"),
# create_training_runtime(name="runtime-2"),
# ]
mock_thread.get.return_value = normalize_model(
models.TrainerV1alpha1TrainingRuntimeList(items=[]),
models.TrainerV1alpha1TrainingRuntimeList,
)

return mock_thread

Expand Down Expand Up @@ -490,6 +500,37 @@ def create_cluster_training_runtime(
)


def create_training_runtime(
    name: str,
    namespace: str = "default",
) -> models.TrainerV1alpha1TrainingRuntime:
    """Build a namespaced (non cluster-scoped) TrainingRuntime fixture.

    Args:
        name: Used as the runtime name, as the value of the runtime
            framework label, and as the name of the embedded JobSet template.
        namespace: Namespace set on both the runtime metadata and the JobSet
            template metadata.

    Returns:
        A TrainingRuntime model with a torch ML policy (2 nodes, 2 processes
        per node) and a single replicated job from ``get_replicated_job()``.
    """
    # ML policy: torch with a fixed number of processes per node.
    torch_source = models.TrainerV1alpha1TorchMLPolicySource(
        numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2)
    )
    ml_policy = models.TrainerV1alpha1MLPolicy(torch=torch_source, numNodes=2)

    # JobSet template carrying the same name/namespace as the runtime itself.
    jobset_template = models.TrainerV1alpha1JobSetTemplateSpec(
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            name=name,
            namespace=namespace,
        ),
        spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]),
    )

    runtime_metadata = models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
        name=name,
        namespace=namespace,
        labels={constants.RUNTIME_FRAMEWORK_LABEL: name},
    )
    runtime_spec = models.TrainerV1alpha1TrainingRuntimeSpec(
        mlPolicy=ml_policy,
        template=jobset_template,
    )

    return models.TrainerV1alpha1TrainingRuntime(
        apiVersion=constants.API_VERSION,
        kind="TrainingRuntime",
        metadata=runtime_metadata,
        spec=runtime_spec,
    )


Comment on lines +503 to +533
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you mean to create this in kubernetes/backend_test.py?
this is not a test function and I believe it should be added to the TrainerClient and propagated to the different backends.

def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob:
return models.JobsetV1alpha2ReplicatedJob(
name="node",
Expand Down
6 changes: 6 additions & 0 deletions kubeflow/trainer/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
# The plural for the ClusterTrainingRuntime.
CLUSTER_TRAINING_RUNTIME_PLURAL = "clustertrainingruntimes"

# The Kind name for the TrainingRuntime.
TRAINING_RUNTIME_KIND = "TrainingRuntime"

# The plural for the TrainingRuntime.
TRAINING_RUNTIME_PLURAL = "trainingruntimes"

# The Kind name for the TrainJob.
TRAINJOB_KIND = "TrainJob"

Expand Down
Loading