1313# limitations under the License.
1414
1515import logging
16- from typing import Dict , List , Optional , Set , Union
16+ from typing import Optional , Union
1717
1818from kubeflow .trainer .constants import constants
1919from kubeflow .trainer .types import types
@@ -35,42 +35,45 @@ def __init__(
3535 backend_config: Backend configuration. Either KubernetesBackendConfig or
3636 LocalProcessBackendConfig, or None to use the backend's
3737 default config class. Defaults to KubernetesBackendConfig.
38+
39+ Raises:
40+ ValueError: Invalid backend configuration.
41+
3842 """
3943 # initialize training backend
4044 if isinstance (backend_config , KubernetesBackendConfig ):
4145 self .backend = KubernetesBackend (backend_config )
4246 else :
4347 raise ValueError ("Invalid backend config '{}'" .format (backend_config ))
4448
45- def list_runtimes (self ) -> types .Runtime :
46- """List of the available Runtimes .
49+ def list_runtimes (self ) -> list [ types .Runtime ] :
50+ """List of the available runtimes.
4751
4852 Returns:
49- List[Runtime]: List of available training runtimes.
50- If no runtimes exist, an empty list is returned.
53+ A list of available training runtimes. If no runtimes exist, an empty list is returned.
5154
5255 Raises:
53- TimeoutError: Timeout to list Runtimes .
54- RuntimeError: Failed to list Runtimes .
56+ TimeoutError: Timeout to list runtimes.
57+ RuntimeError: Failed to list runtimes.
5558 """
5659 return self .backend .list_runtimes ()
5760
5861 def get_runtime (self , name : str ) -> types .Runtime :
59- """Get the Runtime object
62+ """Get the runtime object.
6063 Args:
6164 name: Name of the runtime.
62- Returns:
63- types.TrainingRuntime: Runtime object.
65+
66+ Returns:
67+ A runtime object.
6468 """
6569 return self .backend .get_runtime (name = name )
6670
6771 def get_runtime_packages (self , runtime : types .Runtime ):
68- """
69- Print the installed Python packages for the given Runtime. If Runtime has GPUs it also
72+ """Print the installed Python packages for the given runtime. If a runtime has GPUs it also
7073 prints available GPUs on the single training node.
7174
7275 Args:
73- runtime: Reference to one of existing Runtimes .
76+ runtime: Reference to one of the existing runtimes.
7477
7578 Raises:
7679 ValueError: Input arguments are invalid.
@@ -81,33 +84,40 @@ def get_runtime_packages(self, runtime: types.Runtime):
8184
8285 def train (
8386 self ,
84- runtime : types .Runtime = None ,
87+ runtime : Optional [ types .Runtime ] = None ,
8588 initializer : Optional [types .Initializer ] = None ,
8689 trainer : Optional [Union [types .CustomTrainer , types .BuiltinTrainer ]] = None ,
8790 ) -> str :
88- """
89- Create the TrainJob. You can configure these types of training task:
90- - Custom Training Task: Training with a self-contained function that encapsulates
91- the entire model training process, e.g. `CustomTrainer`.
92- - Config-driven Task with Existing Trainer: Training with a trainer that already includes
93- the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.
91+ """Create a TrainJob. You can configure the TrainJob using one of these trainers:
92+
93+ - CustomTrainer: Runs training with a user-defined function that fully encapsulates the
94+ training process.
95+ - BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring
96+ only parameter configuration.
97+
9498 Args:
95- runtime: Reference to one of existing Runtimes.
96- initializer:
97- Configuration for the dataset and model initializers.
98- trainer:
99- Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.
99+ runtime: Optional reference to one of the existing runtimes. Defaults to the
100+ torch-distributed runtime if not provided.
101+ initializer: Optional configuration for the dataset and model initializers.
102+ trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
103+ the TrainJob will use the runtime's default values.
104+
100105 Returns:
101- str: The unique name of the TrainJob that has been generated.
106+ The unique name of the TrainJob that has been generated.
107+
102108 Raises:
103109 ValueError: Input arguments are invalid.
104110 TimeoutError: Timeout to create TrainJobs.
105111 RuntimeError: Failed to create TrainJobs.
106112 """
107113 return self .backend .train (runtime = runtime , initializer = initializer , trainer = trainer )
108114
109- def list_jobs (self , runtime : Optional [types .Runtime ] = None ) -> List [types .TrainJob ]:
110- """List of all TrainJobs.
115+ def list_jobs (self , runtime : Optional [types .Runtime ] = None ) -> list [types .TrainJob ]:
116+ """List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
117+ that runtime are returned.
118+
119+ Args:
120+ runtime: Reference to one of the existing runtimes.
111121
112122 Returns:
113123 A list of created TrainJobs.
@@ -120,7 +130,19 @@ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.Train
120130 return self .backend .list_jobs (runtime = runtime )
121131
122132 def get_job (self , name : str ) -> types .TrainJob :
123- """Get the TrainJob object"""
133+ """Get the TrainJob object.
134+
135+ Args:
136+ name: Name of the TrainJob.
137+
138+ Returns:
139+ A TrainJob object.
140+
141+ Raises:
142+ TimeoutError: Timeout to get a TrainJob.
143+ RuntimeError: Failed to get a TrainJob.
144+ """
145+
124146 return self .backend .get_job (name = name )
125147
126148 def get_job_logs (
@@ -129,28 +151,29 @@ def get_job_logs(
129151 follow : Optional [bool ] = False ,
130152 step : str = constants .NODE ,
131153 node_rank : int = 0 ,
132- ) -> Dict [str , str ]:
154+ ) -> dict [str , str ]:
133155 """Get the logs from a TrainJob."""
134156 return self .backend .get_job_logs (name = name , follow = follow , step = step , node_rank = node_rank )
135157
136158 def wait_for_job_status (
137159 self ,
138160 name : str ,
139- status : Set [str ] = {constants .TRAINJOB_COMPLETE },
161+ status : set [str ] = {constants .TRAINJOB_COMPLETE },
140162 timeout : int = 600 ,
141163 polling_interval : int = 2 ,
142164 ) -> types .TrainJob :
143- """Wait for TrainJob to reach the desired status
165+ """Wait for a TrainJob to reach a desired status.
144166
145167 Args:
146168 name: Name of the TrainJob.
147- status: Set of expected statuses. It must be subset of Created, Running, Complete, and
169+ status: Expected statuses. Must be a subset of Created, Running, Complete, and
148170 Failed statuses.
149- timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
171+ timeout: Maximum number of seconds to wait for the TrainJob to reach one of the
172+ expected statuses.
150173 polling_interval: The polling interval in seconds to check TrainJob status.
151174
152175 Returns:
153- TrainJob: The training job that reaches the desired status.
176+ A TrainJob object that reaches the desired status.
154177
155178 Raises:
156179 ValueError: The input values are incorrect.
0 commit comments