Skip to content

Commit f342acd

Browse files
chore(trainer): Use explicit exception chaining (#80)
* chore(trainer): Use explicit exception chaining Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> * Update kubeflow/trainer/api/trainer_client.py Co-authored-by: Anya Kramar <akramar@redhat.com> Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> * Use builtin types in all files Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> --------- Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> Co-authored-by: Anya Kramar <akramar@redhat.com>
1 parent 3bd3cd9 commit f342acd

File tree

6 files changed

+116
-87
lines changed

6 files changed

+116
-87
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Dict, List, Optional, Set, Union
16+
from typing import Optional, Union
1717

1818
from kubeflow.trainer.constants import constants
1919
from kubeflow.trainer.types import types
@@ -35,42 +35,45 @@ def __init__(
3535
backend_config: Backend configuration. Either KubernetesBackendConfig or
3636
LocalProcessBackendConfig, or None to use the backend's
3737
default config class. Defaults to KubernetesBackendConfig.
38+
39+
Raises:
40+
ValueError: Invalid backend configuration.
41+
3842
"""
3943
# initialize training backend
4044
if isinstance(backend_config, KubernetesBackendConfig):
4145
self.backend = KubernetesBackend(backend_config)
4246
else:
4347
raise ValueError("Invalid backend config '{}'".format(backend_config))
4448

45-
def list_runtimes(self) -> types.Runtime:
46-
"""List of the available Runtimes.
49+
def list_runtimes(self) -> list[types.Runtime]:
50+
"""List of the available runtimes.
4751
4852
Returns:
49-
List[Runtime]: List of available training runtimes.
50-
If no runtimes exist, an empty list is returned.
53+
A list of available training runtimes. If no runtimes exist, an empty list is returned.
5154
5255
Raises:
53-
TimeoutError: Timeout to list Runtimes.
54-
RuntimeError: Failed to list Runtimes.
56+
TimeoutError: Timeout to list runtimes.
57+
RuntimeError: Failed to list runtimes.
5558
"""
5659
return self.backend.list_runtimes()
5760

5861
def get_runtime(self, name: str) -> types.Runtime:
59-
"""Get the Runtime object
62+
"""Get the runtime object
6063
Args:
6164
name: Name of the runtime.
62-
Returns:
63-
types.TrainingRuntime: Runtime object.
65+
66+
Returns:
67+
A runtime object.
6468
"""
6569
return self.backend.get_runtime(name=name)
6670

6771
def get_runtime_packages(self, runtime: types.Runtime):
68-
"""
69-
Print the installed Python packages for the given Runtime. If Runtime has GPUs it also
72+
"""Print the installed Python packages for the given runtime. If a runtime has GPUs it also
7073
prints available GPUs on the single training node.
7174
7275
Args:
73-
runtime: Reference to one of existing Runtimes.
76+
runtime: Reference to one of existing runtimes.
7477
7578
Raises:
7679
ValueError: Input arguments are invalid.
@@ -81,33 +84,40 @@ def get_runtime_packages(self, runtime: types.Runtime):
8184

8285
def train(
8386
self,
84-
runtime: types.Runtime = None,
87+
runtime: Optional[types.Runtime] = None,
8588
initializer: Optional[types.Initializer] = None,
8689
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
8790
) -> str:
88-
"""
89-
Create the TrainJob. You can configure these types of training task:
90-
- Custom Training Task: Training with a self-contained function that encapsulates
91-
the entire model training process, e.g. `CustomTrainer`.
92-
- Config-driven Task with Existing Trainer: Training with a trainer that already includes
93-
the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.
91+
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:
92+
93+
- CustomTrainer: Runs training with a user-defined function that fully encapsulates the
94+
training process.
95+
- BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring
96+
only parameter configuration.
97+
9498
Args:
95-
runtime: Reference to one of existing Runtimes.
96-
initializer:
97-
Configuration for the dataset and model initializers.
98-
trainer:
99-
Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.
99+
runtime: Optional reference to one of the existing runtimes. Defaults to the
100+
torch-distributed runtime if not provided.
101+
initializer: Optional configuration for the dataset and model initializers.
102+
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
103+
the TrainJob will use the runtime's default values.
104+
100105
Returns:
101-
str: The unique name of the TrainJob that has been generated.
106+
The unique name of the TrainJob that has been generated.
107+
102108
Raises:
103109
ValueError: Input arguments are invalid.
104110
TimeoutError: Timeout to create TrainJobs.
105111
RuntimeError: Failed to create TrainJobs.
106112
"""
107113
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)
108114

109-
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
110-
"""List of all TrainJobs.
115+
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
116+
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
117+
that runtime are returned.
118+
119+
Args:
120+
runtime: Reference to one of the existing runtimes.
111121
112122
Returns:
113123
List: List of created TrainJobs.
@@ -120,7 +130,19 @@ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.Train
120130
return self.backend.list_jobs(runtime=runtime)
121131

122132
def get_job(self, name: str) -> types.TrainJob:
123-
"""Get the TrainJob object"""
133+
"""Get the TrainJob object
134+
135+
Args:
136+
name: Name of the TrainJob.
137+
138+
Returns:
139+
A TrainJob object.
140+
141+
Raises:
142+
TimeoutError: Timeout to get a TrainJob.
143+
RuntimeError: Failed to get a TrainJob.
144+
"""
145+
124146
return self.backend.get_job(name=name)
125147

126148
def get_job_logs(
@@ -129,28 +151,29 @@ def get_job_logs(
129151
follow: Optional[bool] = False,
130152
step: str = constants.NODE,
131153
node_rank: int = 0,
132-
) -> Dict[str, str]:
154+
) -> dict[str, str]:
133155
"""Get the logs from TrainJob"""
134156
return self.backend.get_job_logs(name=name, follow=follow, step=step, node_rank=node_rank)
135157

136158
def wait_for_job_status(
137159
self,
138160
name: str,
139-
status: Set[str] = {constants.TRAINJOB_COMPLETE},
161+
status: set[str] = {constants.TRAINJOB_COMPLETE},
140162
timeout: int = 600,
141163
polling_interval: int = 2,
142164
) -> types.TrainJob:
143-
"""Wait for TrainJob to reach the desired status
165+
"""Wait for a TrainJob to reach a desired status.
144166
145167
Args:
146168
name: Name of the TrainJob.
147-
status: Set of expected statuses. It must be subset of Created, Running, Complete, and
169+
status: Expected statuses. Must be a subset of Created, Running, Complete, and
148170
Failed statuses.
149-
timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
171+
timeout: Maximum number of seconds to wait for the TrainJob to reach one of the
172+
expected statuses.
150173
polling_interval: The polling interval in seconds to check TrainJob status.
151174
152175
Returns:
153-
TrainJob: The training job that reaches the desired status.
176+
A TrainJob object that reaches the desired status.
154177
155178
Raises:
156179
ValueError: The input values are incorrect.

kubeflow/trainer/backends/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import abc
1616

17-
from typing import Dict, List, Optional, Set, Union
17+
from typing import Optional, Union
1818
from kubeflow.trainer.constants import constants
1919
from kubeflow.trainer.types import types
2020

2121

2222
class ExecutionBackend(abc.ABC):
23-
def list_runtimes(self) -> List[types.Runtime]:
23+
def list_runtimes(self) -> list[types.Runtime]:
2424
raise NotImplementedError()
2525

2626
def get_runtime(self, name: str) -> types.Runtime:
@@ -37,7 +37,7 @@ def train(
3737
) -> str:
3838
raise NotImplementedError()
3939

40-
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
40+
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
4141
raise NotImplementedError()
4242

4343
def get_job(self, name: str) -> types.TrainJob:
@@ -49,13 +49,13 @@ def get_job_logs(
4949
follow: Optional[bool] = False,
5050
step: str = constants.NODE,
5151
node_rank: int = 0,
52-
) -> Dict[str, str]:
52+
) -> dict[str, str]:
5353
raise NotImplementedError()
5454

5555
def wait_for_job_status(
5656
self,
5757
name: str,
58-
status: Set[str] = {constants.TRAINJOB_COMPLETE},
58+
status: set[str] = {constants.TRAINJOB_COMPLETE},
5959
timeout: int = 600,
6060
polling_interval: int = 2,
6161
) -> types.TrainJob:

0 commit comments

Comments
 (0)