
Commit ffc3d62

szaher and andreyvelich authored
feat: Implement TrainerClient Backends & Local Process (#33)
* Implement TrainerClient Backends & Local Process
* Implement Job Cancellation
* Update local job to add resource limitation in k8s style
* Update python/kubeflow/trainer/api/trainer_client.py
* Fix linting issues
* Fix unit tests
* Add support for wait_for_job_status
* Update data types
* Fix merge conflict
* Remove TypeAlias
* Replace TRAINER_BACKEND_REGISTRY with TRAINER_BACKEND
* Update kubeflow/trainer/api/trainer_client.py
* Restructure training backends into separate dirs
* Add get_runtime_packages as not supported by local-exec
* Move backends and their configs to kubeflow.trainer
* Fix typo in delete_job
* Move local_runtimes to constants
* Allow list_jobs to filter by runtime
* Keep runtime ref in __local_jobs
* Use Google-style docstrings for LocalJob
* Remove debug opt from LocalProcessConfig
* Only use imports from kubeflow.trainer for backends
* Update local-exec to use only one step: while dividing the work into steps makes debugging and extension easier, per review comments on this PR all train job scripts are consolidated into one script run as a single step to match k8s.
* Optimize loops when getting runtime
* Add LocalRuntimeTrainer
* Rename cleanup config item to cleanup_venv
* Convert local runtime to runtime
* Convert runtimes before returning
* Fix get_job_logs to align with parent interface
* Rename get_runtime_trainer func
* Rename get_training_job_command to get_local_train_job_script
* Ignore failures in Coveralls action

Signed-off-by: Saad Zaher <szaher@redhat.com>
Signed-off-by: Saad Zaher <eng.szaher@gmail.com>
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
1 parent 6709dcf commit ffc3d62

File tree

9 files changed: +891 −2 lines

.github/workflows/test-python.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ['3.9', '3.11']
+        python-version: ["3.9", "3.11"]
 
     name: Test (Python ${{ matrix.python-version }})
@@ -36,6 +36,7 @@ jobs:
       - name: Upload coverage to Coveralls
         uses: coverallsapp/github-action@v2
+        continue-on-error: true
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           parallel: true
@@ -48,6 +49,7 @@ jobs:
     steps:
       - name: Close parallel build
         uses: coverallsapp/github-action@v2
+        continue-on-error: true
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           parallel-finished: true
```

kubeflow/trainer/__init__.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -38,6 +38,11 @@
     TrainerType,
 )
 
+# import backends and its associated configs
+from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
+from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
+
+
 __all__ = [
     "BuiltinTrainer",
     "CustomTrainer",
@@ -55,4 +60,6 @@
     "RuntimeTrainer",
     "TrainerClient",
     "TrainerType",
+    "LocalProcessBackendConfig",
+    "KubernetesBackendConfig",
 ]
```

kubeflow/trainer/api/trainer_client.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -19,6 +19,8 @@
 from kubeflow.trainer.types import types
 from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
 from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
+from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
+from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackendConfig
 
 
 logger = logging.getLogger(__name__)
@@ -27,7 +29,9 @@
 class TrainerClient:
     def __init__(
         self,
-        backend_config: KubernetesBackendConfig = KubernetesBackendConfig(),
+        backend_config: Union[
+            KubernetesBackendConfig, LocalProcessBackendConfig
+        ] = KubernetesBackendConfig(),
     ):
         """Initialize a Kubeflow Trainer client.
 
@@ -43,6 +47,8 @@ def __init__(
         # initialize training backend
         if isinstance(backend_config, KubernetesBackendConfig):
             self.backend = KubernetesBackend(backend_config)
+        elif isinstance(backend_config, LocalProcessBackendConfig):
+            self.backend = LocalProcessBackend(backend_config)
         else:
             raise ValueError("Invalid backend config '{}'".format(backend_config))
```

kubeflow/trainer/backends/localprocess/__init__.py

Whitespace-only changes.
kubeflow/trainer/backends/localprocess/backend.py

Lines changed: 257 additions & 0 deletions
```python
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
import tempfile
import uuid
import random
from datetime import datetime
from typing import List, Optional, Set, Union, Iterator

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.backends.base import ExecutionBackend
from kubeflow.trainer.backends.localprocess.types import (
    LocalProcessBackendConfig,
    LocalBackendJobs,
    LocalBackendStep,
)
from kubeflow.trainer.backends.localprocess.constants import local_runtimes
from kubeflow.trainer.backends.localprocess.job import LocalJob
from kubeflow.trainer.backends.localprocess import utils as local_utils

logger = logging.getLogger(__name__)


class LocalProcessBackend(ExecutionBackend):
    def __init__(
        self,
        cfg: LocalProcessBackendConfig,
    ):
        # list of running subprocesses
        self.__local_jobs: List[LocalBackendJobs] = []
        self.cfg = cfg

    def list_runtimes(self) -> List[types.Runtime]:
        return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes]

    def get_runtime(self, name: str) -> types.Runtime:
        runtime = next(
            (
                self.__convert_local_runtime_to_runtime(rt)
                for rt in local_runtimes
                if rt.name == name
            ),
            None,
        )
        if not runtime:
            raise ValueError(f"Runtime '{name}' not found.")

        return runtime

    def get_runtime_packages(self, runtime: types.Runtime):
        runtime = next((rt for rt in local_runtimes if rt.name == runtime.name), None)
        if not runtime:
            raise ValueError(f"Runtime '{runtime.name}' not found.")

        return runtime.trainer.packages

    def train(
        self,
        runtime: Optional[types.Runtime] = None,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
    ) -> str:
        # set train job name
        train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
        # localprocess backend only supports CustomTrainer
        if not isinstance(trainer, types.CustomTrainer):
            raise ValueError("CustomTrainer must be set with LocalProcessBackend")

        # create temp dir
        venv_dir = tempfile.mkdtemp(prefix=train_job_name)
        logger.debug("operating in {}".format(venv_dir))

        runtime.trainer = local_utils.get_local_runtime_trainer(
            runtime_name=runtime.name,
            venv_dir=venv_dir,
            framework=runtime.trainer.framework,
        )

        # build training job command
        training_command = local_utils.get_local_train_job_script(
            trainer=trainer,
            runtime=runtime,
            train_job_name=train_job_name,
            venv_dir=venv_dir,
            cleanup_venv=self.cfg.cleanup_venv,
        )

        # set the command in the runtime trainer
        runtime.trainer.set_command(training_command)

        # create subprocess object
        train_job = LocalJob(
            name="{}-train".format(train_job_name),
            command=training_command,
            execution_dir=venv_dir,
            env=trainer.env,
            dependencies=[],
        )

        self.__register_job(
            train_job_name=train_job_name,
            step_name="train",
            job=train_job,
            runtime=runtime,
        )
        # start the job.
        train_job.start()

        return train_job_name

    def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
        result = []

        for _job in self.__local_jobs:
            if runtime and _job.runtime.name != runtime.name:
                continue
            result.append(
                types.TrainJob(
                    name=_job.name,
                    creation_timestamp=_job.created,
                    runtime=runtime,
                    num_nodes=1,
                    steps=[
                        types.Step(name=s.step_name, pod_name=s.step_name, status=s.job.status)
                        for s in _job.steps
                    ],
                )
            )
        return result

    def get_job(self, name: str) -> Optional[types.TrainJob]:
        _job = next((j for j in self.__local_jobs if j.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        # check and set the correct job status to match `TrainerClient` supported statuses
        status = self.__get_job_status(_job)

        return types.TrainJob(
            name=_job.name,
            creation_timestamp=_job.created,
            steps=[
                types.Step(name=_step.step_name, pod_name=_step.step_name, status=_step.job.status)
                for _step in _job.steps
            ],
            runtime=_job.runtime,
            num_nodes=1,
            status=status,
        )

    def get_job_logs(
        self,
        name: str,
        step: str = constants.NODE + "-0",
        follow: Optional[bool] = False,
    ) -> Iterator[str]:
        _job = [j for j in self.__local_jobs if j.name == name]
        if not _job:
            raise ValueError("No TrainJob with name '%s'" % name)

        want_all_steps = step == constants.NODE + "-0"

        for _step in _job[0].steps:
            if not want_all_steps and _step.step_name != step:
                continue
            # Flatten the generator and pass through flags so it behaves as expected
            # (adjust args if stream_logs has different signature)
            yield from _step.job.logs(follow=follow)

    def wait_for_job_status(
        self,
        name: str,
        status: Set[str] = {constants.TRAINJOB_COMPLETE},
        timeout: int = 600,
        polling_interval: int = 2,
    ) -> types.TrainJob:
        # find first match or fallback
        _job = next((_job for _job in self.__local_jobs if _job.name == name), None)

        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)
        # find a better implementation for this
        for _step in _job.steps:
            if _step.job.status in [constants.TRAINJOB_RUNNING, constants.TRAINJOB_CREATED]:
                _step.job.join(timeout=timeout)
        return self.get_job(name)

    def delete_job(self, name: str):
        # find job first.
        _job = next((j for j in self.__local_jobs if j.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        # cancel all nested step jobs in target job
        _ = [step.job.cancel() for step in _job.steps]
        # remove the job from the list of jobs
        self.__local_jobs.remove(_job)

    def __get_job_status(self, job: LocalBackendJobs) -> str:
        statuses = [_step.job.status for _step in job.steps]
        # if status is running or failed will take precedence over completed
        if constants.TRAINJOB_FAILED in statuses:
            status = constants.TRAINJOB_FAILED
        elif constants.TRAINJOB_RUNNING in statuses:
            status = constants.TRAINJOB_RUNNING
        elif constants.TRAINJOB_CREATED in statuses:
            status = constants.TRAINJOB_CREATED
        else:
            status = constants.TRAINJOB_CREATED

        return status

    def __register_job(
        self,
        train_job_name: str,
        step_name: str,
        job: LocalJob,
        runtime: types.Runtime = None,
    ):
        _job = [j for j in self.__local_jobs if j.name == train_job_name]
        if not _job:
            _job = LocalBackendJobs(name=train_job_name, runtime=runtime, created=datetime.now())
            self.__local_jobs.append(_job)
        else:
            _job = _job[0]
        _step = [s for s in _job.steps if s.step_name == step_name]
        if not _step:
            _step = LocalBackendStep(step_name=step_name, job=job)
            _job.steps.append(_step)
        else:
            logger.warning("Step '{}' already registered.".format(step_name))

    def __convert_local_runtime_to_runtime(self, local_runtime) -> types.Runtime:
        return types.Runtime(
            name=local_runtime.name,
            trainer=types.RuntimeTrainer(
                trainer_type=local_runtime.trainer.trainer_type,
                framework=local_runtime.trainer.framework,
                num_nodes=local_runtime.trainer.num_nodes,
                device_count=local_runtime.trainer.device_count,
                device=local_runtime.trainer.device,
            ),
            pretrained_model=local_runtime.pretrained_model,
        )
```
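Two small pieces of this backend are easy to check in isolation: `train()` builds job names from one random lowercase letter plus eleven hex characters (so names are a fixed 12 characters and never start with a digit), and `__get_job_status` gives failed and running steps precedence over created ones. A minimal sketch, using string literals as stand-ins for the `kubeflow.trainer.constants` values (the real constants may differ):

```python
import random
import string
import uuid

# Stand-ins for constants.TRAINJOB_* from kubeflow.trainer; illustrative only.
TRAINJOB_FAILED = "Failed"
TRAINJOB_RUNNING = "Running"
TRAINJOB_CREATED = "Created"


def make_train_job_name() -> str:
    # Same scheme as LocalProcessBackend.train(): a leading lowercase letter
    # plus uuid4().hex[:11] gives a 12-character name starting with a letter.
    return random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]


def aggregate_status(step_statuses) -> str:
    # Mirrors __get_job_status precedence: failed > running > created.
    if TRAINJOB_FAILED in step_statuses:
        return TRAINJOB_FAILED
    if TRAINJOB_RUNNING in step_statuses:
        return TRAINJOB_RUNNING
    return TRAINJOB_CREATED


print(len(make_train_job_name()))                      # 12
print(aggregate_status(["Created", "Running"]))        # Running
print(aggregate_status(["Running", "Failed"]))         # Failed
```

The leading-letter trick matters because the same naming scheme must stay valid if a job name is later reused as a Kubernetes resource name, which cannot start with a digit.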
