Skip to content

Commit 8552865

Browse files
committed
feat(trainer): Introduce LocalTrainerClient
1 parent 8a1dea0 commit 8552865

File tree

12 files changed

+749
-32
lines changed

12 files changed

+749
-32
lines changed

python/kubeflow/trainer/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
# Import the Kubeflow Trainer client.
2121
from kubeflow.trainer.api.trainer_client import TrainerClient
2222

23+
# Import the Kubeflow Local Trainer client.
24+
from kubeflow.trainer.api.local_trainer_client import LocalTrainerClient
25+
2326
# Import the Kubeflow Trainer constants.
2427
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH
2528

python/kubeflow/trainer/api/abstract_trainer_client.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from typing import Dict, List, Optional
17+
18+
from kubeflow.trainer.constants import constants
19+
from kubeflow.trainer.types import types
20+
21+
22+
class AbstractTrainerClient(ABC):
    """Abstract interface shared by the Kubeflow Trainer client implementations.

    Concrete clients must implement every method declared here.
    """

    @abstractmethod
    def delete_job(self, name: str):
        """Delete the training job with the given name.

        Args:
            name: The name of the training job to delete.
        """

    @abstractmethod
    def get_job(self, name: str) -> types.TrainJob:
        """Return the training job with the given name.

        Args:
            name: The name of the training job to get.

        Returns:
            The matching training job.
        """

    @abstractmethod
    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE,
        node_rank: int = 0,
    ) -> Dict[str, str]:
        """Return the logs of a training job.

        Args:
            name: The name of the training job.
            follow: Whether to follow the logs (default False).
            step: The training job step to target (default ``constants.NODE``).
            node_rank: The node rank to retrieve logs from (default 0).

        Returns:
            A mapping of string keys to log text.
        """

    @abstractmethod
    def get_runtime(self, name: str) -> types.Runtime:
        """Return the runtime with the given name.

        Args:
            name: The name of the runtime.

        Returns:
            The matching runtime.
        """

    @abstractmethod
    def list_jobs(
        self, runtime: Optional[types.Runtime] = None
    ) -> List[types.TrainJob]:
        """List training jobs.

        Args:
            runtime: Optional runtime used to narrow down the returned jobs.

        Returns:
            A list of training jobs.
        """

    @abstractmethod
    def list_runtimes(self) -> List[types.Runtime]:
        """List the available runtimes.

        Returns:
            A list of runtimes.
        """

    @abstractmethod
    def train(
        self,
        runtime: types.Runtime = types.DEFAULT_RUNTIME,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[types.CustomTrainer] = None,
    ) -> str:
        """Start a training job.

        Args:
            runtime: Config for the train job's runtime.
            initializer: Config for dataset and model initialization.
            trainer: Config for the function that encapsulates the training process.

        Returns:
            The name of the training job, as a string.
        """
python/kubeflow/trainer/api/local_trainer_client.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from importlib import resources
16+
from pathlib import Path
17+
from typing import Dict, List, Optional
18+
19+
import yaml
20+
from kubeflow.trainer import models
21+
from kubeflow.trainer.api.abstract_trainer_client import AbstractTrainerClient
22+
from kubeflow.trainer.constants import constants
23+
from kubeflow.trainer.job_runners import DockerJobRunner, JobRunner
24+
from kubeflow.trainer.types import types
25+
from kubeflow.trainer.utils import utils
26+
27+
28+
class LocalTrainerClient(AbstractTrainerClient):
    """LocalTrainerClient exposes functionality for running training jobs locally.

    A Kubernetes cluster is not required.
    It exposes the same interface as the TrainerClient.

    Runtimes are loaded from YAML files found in ``local_runtimes_path`` and
    all job operations are delegated to the configured job runner.

    Args:
        local_runtimes_path: The path to the directory containing runtime YAML files.
            Defaults to the runtimes included with the package.
        job_runner: The job runner to use for local training.
            Options include the DockerJobRunner and PodmanJobRunner.
            Defaults to the Docker job runner.
    """

    def __init__(
        self,
        local_runtimes_path: Optional[Path] = None,
        job_runner: Optional[JobRunner] = None,
    ):
        print(
            "Warning: LocalTrainerClient is an alpha feature for Kubeflow Trainer. "
            "Some features may be unstable or unimplemented."
        )

        if local_runtimes_path is None:
            # Fall back to the runtime definitions shipped inside the package.
            # NOTE: this is an importlib.resources Traversable, not necessarily
            # a filesystem path, so it must only be accessed through the
            # Traversable API (iterdir / is_file / read_text).
            self.local_runtimes_path = (
                resources.files(constants.PACKAGE_NAME) / constants.LOCAL_RUNTIMES_PATH
            )
        else:
            self.local_runtimes_path = local_runtimes_path

        # Only construct the default runner when the caller did not supply one.
        self.job_runner = DockerJobRunner() if job_runner is None else job_runner

    def list_runtimes(self) -> List[types.Runtime]:
        """Lists all runtimes.

        Returns:
            A list of runtime objects, one per runtime definition found in
            the local runtimes directory.
        """
        return [utils.get_runtime_from_crd(cr) for cr in self.__list_runtime_crs()]

    def get_runtime(self, name: str) -> types.Runtime:
        """Get a specific runtime by name.

        Args:
            name: The name of the runtime.

        Returns:
            A runtime object.

        Raises:
            RuntimeError: if the specified runtime cannot be found.
        """
        for candidate in self.list_runtimes():
            if candidate.name == name:
                return candidate
        raise RuntimeError(f"No runtime found with name '{name}'")

    def train(
        self,
        runtime: types.Runtime = types.DEFAULT_RUNTIME,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[types.CustomTrainer] = None,
    ) -> str:
        """Starts a training job.

        Args:
            runtime: Config for the train job's runtime.
            initializer: Config for dataset and model initialization.
                Accepted for interface compatibility; this method does not
                currently use it.
            trainer: Config for the function that encapsulates the model training process.

        Returns:
            The generated name of the training job.

        Raises:
            RuntimeError: if the specified runtime cannot be found,
                or the runtime container cannot be found,
                or the runtime container image is not specified.
        """
        runtime_cr = self.__get_runtime_cr(runtime.name)
        if runtime_cr is None:
            raise RuntimeError(f"No runtime found with name '{runtime.name}'")

        runtime_container = utils.get_runtime_trainer_container(
            runtime_cr.spec.template.spec.replicated_jobs
        )
        if runtime_container is None:
            raise RuntimeError("No runtime container found")

        image = runtime_container.image
        if image is None:
            raise RuntimeError("No runtime container image specified")

        if trainer and trainer.func:
            # A user-supplied training function overrides the container's
            # own entrypoint and arguments.
            entrypoint, command = utils.get_entrypoint_using_train_func(
                runtime,
                trainer.func,
                trainer.func_args,
                trainer.pip_index_url,
                trainer.packages_to_install,
            )
        else:
            # No training function: run the container as the runtime defines it.
            entrypoint = runtime_container.command
            command = runtime_container.args

        # Default to a single node when the trainer does not request more.
        num_nodes = trainer.num_nodes if trainer and trainer.num_nodes else 1

        return self.job_runner.create_job(
            image=image,
            entrypoint=entrypoint,
            command=command,
            num_nodes=num_nodes,
            framework=runtime.trainer.framework,
            runtime_name=runtime.name,
        )

    def list_jobs(
        self, runtime: Optional[types.Runtime] = None
    ) -> List[types.TrainJob]:
        """Lists all training jobs.

        Args:
            runtime: If provided, only return jobs that use the given runtime.

        Returns:
            A list of training jobs.
        """
        runtime_name = runtime.name if runtime else None
        container_jobs = self.job_runner.list_jobs(runtime_name)
        return [
            self.__container_job_to_train_job(container_job)
            for container_job in container_jobs
        ]

    def get_job(self, name: str) -> types.TrainJob:
        """Get a specific training job by name.

        Args:
            name: The name of the training job to get.

        Returns:
            A training job.

        Raises:
            RuntimeError: if the runtime recorded on the job cannot be found
                among the local runtimes.
        """
        container_job = self.job_runner.get_job(name)
        return self.__container_job_to_train_job(container_job)

    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE,
        node_rank: int = 0,
    ) -> Dict[str, str]:
        """Gets logs for the specified training job.

        Args:
            name (str): The name of the training job
            follow (bool): If true, follows job logs and prints them to standard out (default False)
            step (str): The training job step to target (default "node")
            node_rank (int): The node rank to retrieve logs from (default 0)

        Returns:
            Dict[str, str]: The logs of the training job, where the key is the
                step and node rank, and the value is the logs for that node.
        """
        return self.job_runner.get_job_logs(
            job_name=name, follow=follow, step=step, node_rank=node_rank
        )

    def delete_job(self, name: str):
        """Deletes a specific training job.

        Args:
            name: The name of the training job to delete.
        """
        self.job_runner.delete_job(job_name=name)

    def __list_runtime_crs(self) -> List[models.TrainerV1alpha1ClusterTrainingRuntime]:
        """Load every runtime definition from the local runtimes directory.

        Returns:
            The parsed runtime custom resources, ordered by file name.
        """
        runtime_crs = []
        # Sort so the result does not depend on directory iteration order.
        entries = sorted(
            self.local_runtimes_path.iterdir(), key=lambda entry: entry.name
        )
        for entry in entries:
            # Skip sub-directories (e.g. __pycache__); only files hold runtimes.
            if not entry.is_file():
                continue
            # read_text() works for both pathlib.Path and importlib.resources
            # Traversable objects, whereas the builtin open() requires a real
            # filesystem path.
            cr_dict = yaml.safe_load(entry.read_text(encoding="utf-8"))
            cr = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(cr_dict)
            if cr is not None:
                runtime_crs.append(cr)
        return runtime_crs

    def __get_runtime_cr(
        self,
        name: str,
    ) -> Optional[models.TrainerV1alpha1ClusterTrainingRuntime]:
        """Return the runtime custom resource named ``name``, or None if absent."""
        for cr in self.__list_runtime_crs():
            if cr.metadata.name == name:
                return cr
        return None

    def __container_job_to_train_job(
        self, container_job: types.ContainerJob
    ) -> types.TrainJob:
        """Convert a job runner's container job into a TrainJob.

        Raises:
            RuntimeError: if the job's runtime cannot be found (from get_runtime).
        """
        return types.TrainJob(
            name=container_job.name,
            creation_timestamp=container_job.creation_timestamp,
            steps=[container.to_step() for container in container_job.containers],
            runtime=self.get_runtime(container_job.runtime_name),
            status=container_job.status,
        )

python/kubeflow/trainer/api/trainer_client.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
import logging
1616
import multiprocessing
1717
import queue
18-
import random
19-
import string
20-
import uuid
2118
from typing import Dict, List, Optional
2219

2320
import kubeflow.trainer.models as models
21+
from kubeflow.trainer.api.abstract_trainer_client import AbstractTrainerClient
2422
from kubeflow.trainer.constants import constants
2523
from kubeflow.trainer.types import types
2624
from kubeflow.trainer.utils import utils
@@ -29,7 +27,7 @@
2927
logger = logging.getLogger(__name__)
3028

3129

32-
class TrainerClient:
30+
class TrainerClient(AbstractTrainerClient):
3331
def __init__(
3432
self,
3533
config_file: Optional[str] = None,
@@ -105,7 +103,7 @@ def list_runtimes(self) -> List[types.Runtime]:
105103
return result
106104

107105
for runtime in runtime_list.items:
108-
result.append(self.__get_runtime_from_crd(runtime))
106+
result.append(utils.get_runtime_from_crd(runtime))
109107

110108
except multiprocessing.TimeoutError:
111109
raise TimeoutError(
@@ -147,7 +145,7 @@ def get_runtime(self, name: str) -> types.Runtime:
147145
f"{self.namespace}/{name}"
148146
)
149147

150-
return self.__get_runtime_from_crd(runtime) # type: ignore
148+
return utils.get_runtime_from_crd(runtime) # type: ignore
151149

152150
def train(
153151
self,
@@ -179,7 +177,7 @@ def train(
179177

180178
# Generate unique name for the TrainJob.
181179
# TODO (andreyvelich): Discuss this TrainJob name generation.
182-
train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
180+
train_job_name = utils.generate_train_job_name()
183181

184182
# Build the Trainer.
185183
trainer_crd = models.TrainerV1alpha1Trainer()
@@ -463,30 +461,6 @@ def delete_job(self, name: str):
463461
f"{constants.TRAINJOB_KIND} {self.namespace}/{name} has been deleted"
464462
)
465463

466-
def __get_runtime_from_crd(
467-
self,
468-
runtime_crd: models.TrainerV1alpha1ClusterTrainingRuntime,
469-
) -> types.Runtime:
470-
471-
if not (
472-
runtime_crd.metadata
473-
and runtime_crd.metadata.name
474-
and runtime_crd.spec
475-
and runtime_crd.spec.ml_policy
476-
and runtime_crd.spec.template.spec
477-
and runtime_crd.spec.template.spec.replicated_jobs
478-
):
479-
raise Exception(f"ClusterTrainingRuntime CRD is invalid: {runtime_crd}")
480-
481-
return types.Runtime(
482-
name=runtime_crd.metadata.name,
483-
trainer=utils.get_runtime_trainer(
484-
runtime_crd.spec.template.spec.replicated_jobs,
485-
runtime_crd.spec.ml_policy,
486-
runtime_crd.metadata,
487-
),
488-
)
489-
490464
def __get_trainjob_from_crd(
491465
self,
492466
trainjob_crd: models.TrainerV1alpha1TrainJob,

0 commit comments

Comments
 (0)