-
Notifications
You must be signed in to change notification settings - Fork 45
feat(trainer): Support namespaced TrainingRuntime in the SDK #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -59,21 +59,33 @@ def __init__(self, cfg: KubernetesBackendConfig): | |||||
| def list_runtimes(self) -> list[types.Runtime]: | ||||||
| result = [] | ||||||
| try: | ||||||
| thread = self.custom_api.list_cluster_custom_object( | ||||||
| cluster_thread = self.custom_api.list_cluster_custom_object( | ||||||
| constants.GROUP, | ||||||
| constants.VERSION, | ||||||
| constants.CLUSTER_TRAINING_RUNTIME_PLURAL, | ||||||
| async_req=True, | ||||||
| ) | ||||||
|
|
||||||
| runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict( | ||||||
| thread.get(common_constants.DEFAULT_TIMEOUT) | ||||||
| namespace_thread = self.custom_api.list_namespaced_custom_object( | ||||||
| constants.GROUP, | ||||||
| constants.VERSION, | ||||||
| self.namespace, | ||||||
| constants.TRAINING_RUNTIME_PLURAL, | ||||||
| async_req=True, | ||||||
| ) | ||||||
|
|
||||||
| cluster_runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict( | ||||||
| cluster_thread.get(constants.DEFAULT_TIMEOUT) | ||||||
| ) | ||||||
|
|
||||||
| namespace_runtime_list = models.TrainerV1alpha1TrainingRuntimeList.from_dict( | ||||||
| namespace_thread.get(constants.DEFAULT_TIMEOUT) | ||||||
| ) | ||||||
|
|
||||||
| if not runtime_list: | ||||||
| if not (cluster_runtime_list or namespace_runtime_list): | ||||||
| return result | ||||||
|
|
||||||
| for runtime in runtime_list.items: | ||||||
| for runtime in namespace_runtime_list.items + cluster_runtime_list.items: | ||||||
| if not ( | ||||||
| runtime.metadata | ||||||
| and runtime.metadata.labels | ||||||
|
|
@@ -88,33 +100,55 @@ def list_runtimes(self) -> list[types.Runtime]: | |||||
|
|
||||||
| except multiprocessing.TimeoutError as e: | ||||||
| raise TimeoutError( | ||||||
| f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " | ||||||
| "Timeout to list " | ||||||
| f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s " | ||||||
| f"in namespace: {self.namespace}" | ||||||
| ) from e | ||||||
| except Exception as e: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " | ||||||
| "Failed to list " | ||||||
| f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s " | ||||||
| f"in namespace: {self.namespace}" | ||||||
| ) from e | ||||||
|
|
||||||
| return result | ||||||
|
|
||||||
| def get_runtime(self, name: str) -> types.Runtime: | ||||||
| """Get the the Runtime object""" | ||||||
| """Get the the Runtime object prefer namespaced, fall-back to cluster-scoped""" | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same change goes for each occurence here |
||||||
|
|
||||||
| try: | ||||||
| thread = self.custom_api.get_cluster_custom_object( | ||||||
| cluster_thread = self.custom_api.get_cluster_custom_object( | ||||||
| constants.GROUP, | ||||||
| constants.VERSION, | ||||||
| constants.CLUSTER_TRAINING_RUNTIME_PLURAL, | ||||||
| name, | ||||||
| async_req=True, | ||||||
| ) | ||||||
|
|
||||||
| runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict( | ||||||
| thread.get(common_constants.DEFAULT_TIMEOUT) # type: ignore | ||||||
| namespace_thread = self.custom_api.get_namespaced_custom_object( | ||||||
| constants.GROUP, | ||||||
| constants.VERSION, | ||||||
| self.namespace, | ||||||
| constants.TRAINING_RUNTIME_PLURAL, | ||||||
| name, | ||||||
| async_req=True, | ||||||
| ) | ||||||
|
|
||||||
| # Try namespaced runtime first, fall back to cluster-scoped one | ||||||
| try: | ||||||
| runtime = models.TrainerV1alpha1TrainingRuntime.from_dict( | ||||||
| namespace_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore | ||||||
| ) | ||||||
| except Exception as e: | ||||||
| logger.warning( | ||||||
| f"Namespaced TrainingRuntime '{self.namespace}/{name}' not found " | ||||||
| f"({type(e).__name__}: {e}); falling back to cluster-scoped runtime." | ||||||
| ) | ||||||
|
|
||||||
| runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict( | ||||||
| cluster_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore | ||||||
| ) | ||||||
|
|
||||||
| except multiprocessing.TimeoutError as e: | ||||||
| raise TimeoutError( | ||||||
| f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: " | ||||||
|
|
@@ -396,8 +430,13 @@ def delete_job(self, name: str): | |||||
|
|
||||||
| def __get_runtime_from_cr( | ||||||
| self, | ||||||
| runtime_cr: models.TrainerV1alpha1ClusterTrainingRuntime, | ||||||
| runtime_cr: Union[ | ||||||
| models.TrainerV1alpha1ClusterTrainingRuntime, models.TrainerV1alpha1TrainingRuntime | ||||||
| ], | ||||||
| ) -> types.Runtime: | ||||||
| crd_kind = getattr(runtime_cr, "kind", "UnknownKind") | ||||||
| crd_name = getattr(runtime_cr.metadata, "name", "UnknownName") | ||||||
|
|
||||||
| if not ( | ||||||
| runtime_cr.metadata | ||||||
| and runtime_cr.metadata.name | ||||||
|
|
@@ -406,7 +445,11 @@ def __get_runtime_from_cr( | |||||
| and runtime_cr.spec.template.spec | ||||||
| and runtime_cr.spec.template.spec.replicated_jobs | ||||||
| ): | ||||||
| raise Exception(f"ClusterTrainingRuntime CR is invalid: {runtime_cr}") | ||||||
| raise Exception( | ||||||
| f"{crd_kind} '{crd_name}' is invalid — missing one or more required fields: " | ||||||
| f"metadata.name, spec.mlPolicy, spec.template.spec.replicatedJobs.\n" | ||||||
| f"Full object: {runtime_cr}" | ||||||
| ) | ||||||
|
|
||||||
| if not ( | ||||||
| runtime_cr.metadata.labels | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -357,6 +357,16 @@ def list_namespaced_custom_object_response(*args, **kwargs): | |
| models.TrainerV1alpha1TrainJobList(items=items), | ||
| models.TrainerV1alpha1TrainJobList, | ||
| ) | ||
| elif args[3] == constants.TRAINING_RUNTIME_PLURAL: | ||
| # TODO: add test case for namespace scoped runtimes | ||
| # items = [ | ||
| # create_training_runtime(name="runtime-1"), | ||
| # create_training_runtime(name="runtime-2"), | ||
| # ] | ||
| mock_thread.get.return_value = normalize_model( | ||
| models.TrainerV1alpha1TrainingRuntimeList(items=[]), | ||
| models.TrainerV1alpha1TrainingRuntimeList, | ||
| ) | ||
|
|
||
| return mock_thread | ||
|
|
||
|
|
@@ -490,6 +500,37 @@ def create_cluster_training_runtime( | |
| ) | ||
|
|
||
|
|
||
| def create_training_runtime( | ||
| name: str, | ||
| namespace: str = "default", | ||
| ) -> models.TrainerV1alpha1TrainingRuntime: | ||
| """Create a mock namespaced TrainingRuntime object (not cluster-scoped).""" | ||
| return models.TrainerV1alpha1TrainingRuntime( | ||
| apiVersion=constants.API_VERSION, | ||
| kind="TrainingRuntime", | ||
| metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( | ||
| name=name, | ||
| namespace=namespace, | ||
| labels={constants.RUNTIME_FRAMEWORK_LABEL: name}, | ||
| ), | ||
| spec=models.TrainerV1alpha1TrainingRuntimeSpec( | ||
| mlPolicy=models.TrainerV1alpha1MLPolicy( | ||
| torch=models.TrainerV1alpha1TorchMLPolicySource( | ||
| numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2) | ||
| ), | ||
| numNodes=2, | ||
| ), | ||
| template=models.TrainerV1alpha1JobSetTemplateSpec( | ||
| metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( | ||
| name=name, | ||
| namespace=namespace, | ||
| ), | ||
| spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]), | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
|
Comment on lines
+503
to
+533
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you mean to create this in |
||
| def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob: | ||
| return models.JobsetV1alpha2ReplicatedJob( | ||
| name="node", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.