|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from collections.abc import Iterator |
15 | 16 | import copy |
16 | 17 | import logging |
17 | 18 | import multiprocessing |
18 | 19 | import random |
| 20 | +import re |
19 | 21 | import string |
20 | 22 | import time |
| 23 | +from typing import Optional, Union |
21 | 24 | import uuid |
22 | | -from typing import Optional, Union, Iterator |
23 | | -import re |
24 | 25 |
|
25 | | -from kubeflow.trainer.constants import constants |
26 | | -from kubeflow.trainer.types import types |
27 | | -from kubeflow.trainer.utils import utils |
28 | 26 | from kubeflow_trainer_api import models |
29 | 27 | from kubernetes import client, config, watch |
| 28 | + |
30 | 29 | from kubeflow.trainer.backends.base import ExecutionBackend |
31 | 30 | from kubeflow.trainer.backends.kubernetes import types as k8s_types |
| 31 | +from kubeflow.trainer.constants import constants |
| 32 | +from kubeflow.trainer.types import types |
| 33 | +from kubeflow.trainer.utils import utils |
32 | 34 |
|
33 | 35 | logger = logging.getLogger(__name__) |
34 | 36 |
|
@@ -141,8 +143,8 @@ def get_runtime_packages(self, runtime: types.Runtime): |
141 | 143 | runtime_copy.trainer.set_command(tuple(mpi_command)) |
142 | 144 |
|
143 | 145 | def print_packages(): |
144 | | - import subprocess |
145 | 146 | import shutil |
| 147 | + import subprocess |
146 | 148 | import sys |
147 | 149 |
|
148 | 150 | # Print Python version. |
@@ -353,17 +355,15 @@ def get_job_logs( |
353 | 355 | ) |
354 | 356 |
|
355 | 357 | # Stream logs incrementally. |
356 | | - for logline in log_stream: |
357 | | - yield logline # type:ignore |
| 358 | + yield from log_stream |
358 | 359 | else: |
359 | 360 | logs = self.core_api.read_namespaced_pod_log( |
360 | 361 | name=pod_name, |
361 | 362 | namespace=self.namespace, |
362 | 363 | container=container_name, |
363 | 364 | ) |
364 | 365 |
|
365 | | - for line in logs.splitlines(): |
366 | | - yield line |
| 366 | + yield from logs.splitlines() |
367 | 367 |
|
368 | 368 | except Exception as e: |
369 | 369 | raise RuntimeError( |
@@ -554,9 +554,12 @@ def __get_trainjob_from_crd( |
554 | 554 | # Update the TrainJob status from its conditions. |
555 | 555 | if trainjob_crd.status and trainjob_crd.status.conditions: |
556 | 556 | for c in trainjob_crd.status.conditions: |
557 | | - if c.type == constants.TRAINJOB_COMPLETE and c.status == "True": |
558 | | - trainjob.status = c.type |
559 | | - elif c.type == constants.TRAINJOB_FAILED and c.status == "True": |
| 557 | + if ( |
| 558 | + c.type == constants.TRAINJOB_COMPLETE |
| 559 | + and c.status == "True" |
| 560 | + or c.type == constants.TRAINJOB_FAILED |
| 561 | + and c.status == "True" |
| 562 | + ): |
560 | 563 | trainjob.status = c.type |
561 | 564 | else: |
562 | 565 | # The TrainJob running status is defined when all training node (e.g. Pods) are |
|
0 commit comments