
Commit 0f7a988

feat(trainer): Refactor get_job_logs() API with Iterator (#83)
* feat(trainer): Refactor get_job_logs() API
* Fix unit tests
* Remove unused func
* Update kubeflow/trainer/api/trainer_client.py
* Fix print logs
* Rename TrainerClient to KubernetesBackend in tests
* Remove empty return from watch stream logs

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
Co-authored-by: Anya Kramar <akramar@redhat.com>
1 parent f342acd commit 0f7a988

6 files changed: +113 −156 lines changed

README.md

Lines changed: 10 additions & 7 deletions

````diff
@@ -5,6 +5,7 @@
 [![Join Slack](https://img.shields.io/badge/Join_Slack-blue?logo=slack)](https://www.kubeflow.org/docs/about/community/#kubeflow-slack-channels)
 [![Coverage Status](https://coveralls.io/repos/github/kubeflow/sdk/badge.svg?branch=main)](https://coveralls.io/github/kubeflow/sdk?branch=main)
 [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/kubeflow/sdk)
+
 <!-- TODO(kramaranya): update when release [![Python Supported Versions](https://img.shields.io/pypi/pyversions/kubeflow.svg?color=%2334D058)](https://pypi.org/project/kubeflow/) -->
 
 ## Overview
@@ -36,6 +37,7 @@ ML applications rather than managing complex infrastrutcure.
 ```bash
 pip install git+https://github.com/kubeflow/sdk.git@main
 ```
+
 <!-- TODO(kramaranya): update before release pip install -U kubeflow -->
 
 ### Run your first PyTorch distributed job
@@ -49,7 +51,7 @@ def get_torch_dist():
     import torch.distributed as dist
 
     dist.init_process_group(backend="gloo")
-    print(f"PyTorch Distributed Environment")
+    print("PyTorch Distributed Environment")
     print(f"WORLD_SIZE: {dist.get_world_size()}")
     print(f"RANK: {dist.get_rank()}")
     print(f"LOCAL_RANK: {os.environ['LOCAL_RANK']}")
@@ -70,17 +72,17 @@ job_id = TrainerClient().train(
 TrainerClient().wait_for_job_status(job_id)
 
 # Print TrainJob logs
-print(TrainerClient().get_job_logs(name=job_id, node_rank=0)["node-0"])
+print("\n".join(TrainerClient().get_job_logs(name=job_id)))
 ```
 
 ## Supported Kubeflow Projects
 
-| Project | Status | Description |
-|-----------------------------|--------|------------------------------------------------------------|
+| Project                     | Status           | Description                                                 |
+| --------------------------- | ---------------- | ----------------------------------------------------------- |
 | **Kubeflow Trainer**        | ✅ **Available** | Train and fine-tune AI models with various frameworks       |
-| **Kubeflow Katib** | 🚧 Planned | Hyperparameter optimization |
-| **Kubeflow Pipelines** | 🚧 Planned | Build, run, and track AI workflows |
-| **Kubeflow Model Registry** | 🚧 Planned | Manage model artifacts, versions and ML artifacts metadata |
+| **Kubeflow Katib**          | 🚧 Planned       | Hyperparameter optimization                                 |
+| **Kubeflow Pipelines**      | 🚧 Planned       | Build, run, and track AI workflows                          |
+| **Kubeflow Model Registry** | 🚧 Planned       | Manage model artifacts, versions and ML artifacts metadata  |
 
 ## Community
 
@@ -98,6 +100,7 @@ Kubeflow SDK is a community project and is still under active development. We we
 ## Documentation
 
 <!-- TODO(kramaranya): add kubeflow sdk docs -->
+
 - **[Design Document](https://docs.google.com/document/d/1rX7ELAHRb_lvh0Y7BK1HBYAbA0zi9enB0F_358ZC58w/edit)**: Kubeflow SDK design proposal
 - **[Component Guides](https://www.kubeflow.org/docs/components/)**: Individual component documentation
 - **[DeepWiki](https://deepwiki.com/kubeflow/sdk)**: AI-powered repository documentation
````
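With this change the README treats `get_job_logs()` as a plain iterator of log lines. A minimal sketch of the two consumption patterns the new API supports, reusing the `job_id` returned by the README's `train()` call above:

```python
from kubeflow.trainer import TrainerClient

# Batch: drain the iterator and join the lines, as the README now does.
print("\n".join(TrainerClient().get_job_logs(name=job_id)))

# Streaming: pass follow=True to print lines as the Pod emits them.
for logline in TrainerClient().get_job_logs(name=job_id, follow=True):
    print(logline)
```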

kubeflow/trainer/api/trainer_client.py

Lines changed: 28 additions & 8 deletions

````diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
 
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
@@ -120,8 +120,7 @@ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.Train
             runtime: Reference to one of the existing runtimes.
 
         Returns:
-            List: List of created TrainJobs.
-                If no TrainJob exist, an empty list is returned.
+            List of created TrainJobs. If no TrainJob exist, an empty list is returned.
 
         Raises:
             TimeoutError: Timeout to list TrainJobs.
@@ -148,12 +147,33 @@ def get_job(self, name: str) -> types.TrainJob:
     def get_job_logs(
         self,
         name: str,
+        step: str = constants.NODE + "-0",
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
-        """Get the logs from TrainJob"""
-        return self.backend.get_job_logs(name=name, follow=follow, step=step, node_rank=node_rank)
+    ) -> Iterator[str]:
+        """Get logs from a specific step of a TrainJob.
+
+        You can watch for the logs in realtime as follows:
+        ```python
+        from kubeflow.trainer import TrainerClient
+
+        for logline in TrainerClient().get_job_logs(name="s8d44aa4fb6d", follow=True):
+            print(logline)
+        ```
+
+        Args:
+            name: Name of the TrainJob.
+            step: Step of the TrainJob to collect logs from, like dataset-initializer or node-0.
+            follow: Whether to stream logs in realtime as they are produced.
+
+        Returns:
+            Iterator of log lines.
+
+        Raises:
+            TimeoutError: Timeout to get a TrainJob.
+            RuntimeError: Failed to get a TrainJob.
+        """
+        return self.backend.get_job_logs(name=name, follow=follow, step=step)
 
     def wait_for_job_status(
         self,
````
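Note that the client simply returns the backend's generator, so no Kubernetes request is issued until the iterator is consumed. A sketch of reading a non-default step lazily; the TrainJob name `s8d44aa4fb6d` is the placeholder from the docstring above, and `dataset-initializer` is the step name it mentions:

```python
from kubeflow.trainer import TrainerClient

client = TrainerClient()
job_name = "s8d44aa4fb6d"  # hypothetical TrainJob name, as in the docstring example

# No API call happens here: the backend method is a generator, so the
# request is only issued once iteration starts in the loop below.
logs = client.get_job_logs(name=job_name, step="dataset-initializer")

with open(f"{job_name}-dataset-initializer.log", "w") as f:
    for line in logs:
        f.write(line + "\n")
```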

kubeflow/trainer/backends/base.py

Lines changed: 3 additions & 4 deletions

````diff
@@ -14,7 +14,7 @@
 
 import abc
 
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
 
@@ -47,9 +47,8 @@ def get_job_logs(
         self,
         name: str,
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
         raise NotImplementedError()
 
     def wait_for_job_status(
````
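Every backend must now yield log lines instead of building a dict. A minimal sketch of a custom backend honoring the new contract, assuming the abstract class in `base.py` is named `ExecutionBackend` and eliding its other abstract methods:

```python
from typing import Iterator, Optional

from kubeflow.trainer.backends.base import ExecutionBackend  # assumed class name
from kubeflow.trainer.constants import constants


class InMemoryBackend(ExecutionBackend):
    """Hypothetical backend that replays pre-captured logs, e.g. for tests."""

    def __init__(self, logs_by_step: dict[str, str]):
        self.logs_by_step = logs_by_step

    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE + "-0",
    ) -> Iterator[str]:
        # Yield line by line to satisfy the Iterator[str] contract.
        yield from self.logs_by_step.get(step, "").splitlines()
```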

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 26 additions & 70 deletions

````diff
@@ -15,12 +15,12 @@
 import copy
 import logging
 import multiprocessing
-import queue
 import random
 import string
 import time
 import uuid
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
+import re
 
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
@@ -173,7 +173,7 @@ def print_packages():
         )
 
         self.wait_for_job_status(job_name)
-        print(self.get_job_logs(job_name)["node-0"])
+        print("\n".join(self.get_job_logs(name=job_name)))
         self.delete_job(job_name)
 
     def train(
@@ -328,92 +328,48 @@ def get_job_logs(
         self,
         name: str,
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
-        """Get the logs from TrainJob"""
-
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
+        """Get the TrainJob logs"""
         # Get the TrainJob Pod name.
         pod_name = None
         for c in self.get_job(name).steps:
-            if c.status != constants.POD_PENDING:
-                if c.name == step or c.name == f"{step}-{node_rank}":
-                    pod_name = c.pod_name
+            if c.status != constants.POD_PENDING and c.name == step:
+                pod_name = c.pod_name
+                break
         if pod_name is None:
-            return {}
-
-        # Dict where key is the Pod type and value is the Pod logs.
-        logs_dict = {}
-
-        # TODO (andreyvelich): Potentially, refactor this.
-        # Support logging of multiple Pods.
-        # TODO (andreyvelich): Currently, follow is supported only for node container.
-        if follow and step == constants.NODE:
-            log_streams = []
-            log_streams.append(
-                watch.Watch().stream(
-                    self.core_api.read_namespaced_pod_log,
-                    name=pod_name,
-                    namespace=self.namespace,
-                    container=constants.NODE,
-                )
-            )
-            finished = [False] * len(log_streams)
-
-            # Create thread and queue per stream, for non-blocking iteration.
-            log_queue_pool = utils.get_log_queue_pool(log_streams)
-
-            # Iterate over every watching pods' log queue
-            while True:
-                for index, log_queue in enumerate(log_queue_pool):
-                    if all(finished):
-                        break
-                    if finished[index]:
-                        continue
-                    # grouping the every 50 log lines of the same pod.
-                    for _ in range(50):
-                        try:
-                            logline = log_queue.get(timeout=1)
-                            if logline is None:
-                                finished[index] = True
-                                break
-                            # Print logs to the StdOut and update results dict.
-                            print(f"[{step}-{node_rank}]: {logline}")
-                            logs_dict[f"{step}-{node_rank}"] = (
-                                logs_dict.get(f"{step}-{node_rank}", "") + logline + "\n"
-                            )
-                        except queue.Empty:
-                            break
-                if all(finished):
-                    return logs_dict
+            return
 
+        # Remove the number for the node step.
+        container_name = re.sub(r"-\d+$", "", step)
         try:
-            if step == constants.DATASET_INITIALIZER:
-                logs_dict[constants.DATASET_INITIALIZER] = self.core_api.read_namespaced_pod_log(
-                    name=pod_name,
-                    namespace=self.namespace,
-                    container=constants.DATASET_INITIALIZER,
-                )
-            elif step == constants.MODEL_INITIALIZER:
-                logs_dict[constants.MODEL_INITIALIZER] = self.core_api.read_namespaced_pod_log(
+            if follow:
+                log_stream = watch.Watch().stream(
+                    self.core_api.read_namespaced_pod_log,
                     name=pod_name,
                     namespace=self.namespace,
-                    container=constants.MODEL_INITIALIZER,
+                    container=container_name,
+                    follow=True,
                 )
+
+                # Stream logs incrementally.
+                for logline in log_stream:
+                    yield logline  # type:ignore
             else:
-                logs_dict[f"{step}-{node_rank}"] = self.core_api.read_namespaced_pod_log(
+                logs = self.core_api.read_namespaced_pod_log(
                     name=pod_name,
                     namespace=self.namespace,
-                    container=constants.NODE,
+                    container=container_name,
                 )
 
+                for line in logs.splitlines():
+                    yield line
+
         except Exception as e:
             raise RuntimeError(
                 f"Failed to read logs for the pod {self.namespace}/{pod_name}"
             ) from e
 
-        return logs_dict
-
     def wait_for_job_status(
         self,
         name: str,
````
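The step-to-container mapping is the subtle part of the new backend code: `re.sub(r"-\d+$", "", step)` strips a trailing numeric rank, so `node-0` targets the `node` container while initializer steps pass through unchanged. A standalone sketch of that behavior (the helper name is illustrative, not part of the SDK):

```python
import re


def step_to_container(step: str) -> str:
    """Mirror the backend's mapping: drop a trailing '-<rank>' suffix."""
    return re.sub(r"-\d+$", "", step)


assert step_to_container("node-0") == "node"
assert step_to_container("node-12") == "node"
assert step_to_container("dataset-initializer") == "dataset-initializer"
```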
