Skip to content

Commit 9985c29

Browse files
szaherandreyvelich
andauthored
chore: Add proper ruff configuration (#69)
* chore: Add proper ruff configuration Add proper ruff configuration to * support proper linting * support ruff format * support isort with ruff * update pre-commit configs Signed-off-by: Saad Zaher <eng.szaher@gmail.com> * merge with main Signed-off-by: Saad Zaher <szaher@redhat.com> * fix format Signed-off-by: Saad Zaher <szaher@redhat.com> * fix linting issues Signed-off-by: Saad Zaher <szaher@redhat.com> * Update Makefile remove reduandant `uv sync` Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> Signed-off-by: Saad Zaher <szaher@redhat.com> * let linter handle line-length \n add labels next to linters codes Signed-off-by: Saad Zaher <szaher@redhat.com> * remove duplicate labels from linter Signed-off-by: Saad Zaher <szaher@redhat.com> * remove unwanted comment in pyproject.toml Signed-off-by: Saad Zaher <szaher@redhat.com> --------- Signed-off-by: Saad Zaher <eng.szaher@gmail.com> Signed-off-by: Saad Zaher <szaher@redhat.com> Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
1 parent ffc3d62 commit 9985c29

File tree

20 files changed

+677
-602
lines changed

20 files changed

+677
-602
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ repos:
77
- id: end-of-file-fixer
88
- id: trailing-whitespace
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: v0.4.4
10+
# Pin to the latest release as of today
11+
rev: v0.12.10
1112
hooks:
12-
- id: ruff
13-
exclude: |
14-
(?x)^(
15-
kubeflow/trainer/__init__.py|
16-
kubeflow/trainer/api/__init__.py|
17-
kubeflow/trainer/models/.*|
18-
)$
13+
# Lint + auto-fix (must run before format)
14+
- id: ruff-check
15+
args: [ --fix ]
16+
# Format after fixes
17+
- id: ruff-format

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ uv: ## Install UV
4848

4949
.PHONY: ruff
5050
ruff: ## Install Ruff
51-
@uvx ruff --help &> /dev/null || uv tool install ruff
51+
@uv run ruff --help &> /dev/null || uv tool install ruff
5252

5353
.PHONY: verify
5454
verify: install-dev ## install all required tools
5555
@uv lock --check
56-
@uvx ruff check --show-fixes
56+
@uv run ruff check --show-fixes --output-format=github .
57+
@uv run ruff format --check kubeflow
5758

5859
.PHONY: uv-venv
5960
uv-venv:

kubeflow/trainer/__init__.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
# limitations under the License.
1414

1515

16-
from __future__ import absolute_import
17-
1816
# Import the Kubeflow Trainer client.
1917
from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401
2018

19+
# import backends and its associated configs
20+
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
21+
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
22+
2123
# Import the Kubeflow Trainer constants.
2224
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401
2325

@@ -32,17 +34,12 @@
3234
Initializer,
3335
Loss,
3436
Runtime,
37+
RuntimeTrainer,
3538
TorchTuneConfig,
3639
TorchTuneInstructDataset,
37-
RuntimeTrainer,
3840
TrainerType,
3941
)
4042

41-
# import backends and its associated configs
42-
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
43-
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
44-
45-
4643
__all__ = [
4744
"BuiltinTrainer",
4845
"CustomTrainer",

kubeflow/trainer/api/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
# ruff: noqa
22

33
# import apis into api package
4-

kubeflow/trainer/api/trainer_client.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Iterator
1516
import logging
16-
from typing import Optional, Union, Iterator
17+
from typing import Optional, Union
1718

18-
from kubeflow.trainer.constants import constants
19-
from kubeflow.trainer.types import types
2019
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
2120
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
22-
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
23-
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackendConfig
24-
21+
from kubeflow.trainer.backends.localprocess.backend import (
22+
LocalProcessBackend,
23+
LocalProcessBackendConfig,
24+
)
25+
from kubeflow.trainer.constants import constants
26+
from kubeflow.trainer.types import types
2527

2628
logger = logging.getLogger(__name__)
2729

2830

2931
class TrainerClient:
3032
def __init__(
3133
self,
32-
backend_config: Union[
33-
KubernetesBackendConfig, LocalProcessBackendConfig
34-
] = KubernetesBackendConfig(),
34+
backend_config: Union[KubernetesBackendConfig, LocalProcessBackendConfig] = None,
3535
):
3636
"""Initialize a Kubeflow Trainer client.
3737
@@ -45,12 +45,15 @@ def __init__(
4545
4646
"""
4747
# initialize training backend
48+
if not backend_config:
49+
backend_config = KubernetesBackendConfig()
50+
4851
if isinstance(backend_config, KubernetesBackendConfig):
4952
self.backend = KubernetesBackend(backend_config)
5053
elif isinstance(backend_config, LocalProcessBackendConfig):
5154
self.backend = LocalProcessBackend(backend_config)
5255
else:
53-
raise ValueError("Invalid backend config '{}'".format(backend_config))
56+
raise ValueError(f"Invalid backend config '{backend_config}'")
5457

5558
def list_runtimes(self) -> list[types.Runtime]:
5659
"""List of the available runtimes.

kubeflow/trainer/backends/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,27 @@
1313
# limitations under the License.
1414

1515
import abc
16+
from collections.abc import Iterator
17+
from typing import Optional, Union
1618

17-
from typing import Optional, Union, Iterator
1819
from kubeflow.trainer.constants import constants
1920
from kubeflow.trainer.types import types
2021

2122

2223
class ExecutionBackend(abc.ABC):
24+
@abc.abstractmethod
2325
def list_runtimes(self) -> list[types.Runtime]:
2426
raise NotImplementedError()
2527

28+
@abc.abstractmethod
2629
def get_runtime(self, name: str) -> types.Runtime:
2730
raise NotImplementedError()
2831

32+
@abc.abstractmethod
2933
def get_runtime_packages(self, runtime: types.Runtime):
3034
raise NotImplementedError()
3135

36+
@abc.abstractmethod
3237
def train(
3338
self,
3439
runtime: Optional[types.Runtime] = None,
@@ -37,12 +42,15 @@ def train(
3742
) -> str:
3843
raise NotImplementedError()
3944

45+
@abc.abstractmethod
4046
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
4147
raise NotImplementedError()
4248

49+
@abc.abstractmethod
4350
def get_job(self, name: str) -> types.TrainJob:
4451
raise NotImplementedError()
4552

53+
@abc.abstractmethod
4654
def get_job_logs(
4755
self,
4856
name: str,
@@ -51,6 +59,7 @@ def get_job_logs(
5159
) -> Iterator[str]:
5260
raise NotImplementedError()
5361

62+
@abc.abstractmethod
5463
def wait_for_job_status(
5564
self,
5665
name: str,
@@ -60,5 +69,6 @@ def wait_for_job_status(
6069
) -> types.TrainJob:
6170
raise NotImplementedError()
6271

72+
@abc.abstractmethod
6373
def delete_job(self, name: str):
6474
raise NotImplementedError()

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Iterator
1516
import copy
1617
import logging
1718
import multiprocessing
1819
import random
20+
import re
1921
import string
2022
import time
23+
from typing import Optional, Union
2124
import uuid
22-
from typing import Optional, Union, Iterator
23-
import re
2425

25-
from kubeflow.trainer.constants import constants
26-
from kubeflow.trainer.types import types
27-
from kubeflow.trainer.utils import utils
2826
from kubeflow_trainer_api import models
2927
from kubernetes import client, config, watch
28+
3029
from kubeflow.trainer.backends.base import ExecutionBackend
3130
from kubeflow.trainer.backends.kubernetes import types as k8s_types
31+
from kubeflow.trainer.constants import constants
32+
from kubeflow.trainer.types import types
33+
from kubeflow.trainer.utils import utils
3234

3335
logger = logging.getLogger(__name__)
3436

@@ -141,8 +143,8 @@ def get_runtime_packages(self, runtime: types.Runtime):
141143
runtime_copy.trainer.set_command(tuple(mpi_command))
142144

143145
def print_packages():
144-
import subprocess
145146
import shutil
147+
import subprocess
146148
import sys
147149

148150
# Print Python version.
@@ -353,17 +355,15 @@ def get_job_logs(
353355
)
354356

355357
# Stream logs incrementally.
356-
for logline in log_stream:
357-
yield logline # type:ignore
358+
yield from log_stream
358359
else:
359360
logs = self.core_api.read_namespaced_pod_log(
360361
name=pod_name,
361362
namespace=self.namespace,
362363
container=container_name,
363364
)
364365

365-
for line in logs.splitlines():
366-
yield line
366+
yield from logs.splitlines()
367367

368368
except Exception as e:
369369
raise RuntimeError(
@@ -554,9 +554,12 @@ def __get_trainjob_from_crd(
554554
# Update the TrainJob status from its conditions.
555555
if trainjob_crd.status and trainjob_crd.status.conditions:
556556
for c in trainjob_crd.status.conditions:
557-
if c.type == constants.TRAINJOB_COMPLETE and c.status == "True":
558-
trainjob.status = c.type
559-
elif c.type == constants.TRAINJOB_FAILED and c.status == "True":
557+
if (
558+
c.type == constants.TRAINJOB_COMPLETE
559+
and c.status == "True"
560+
or c.type == constants.TRAINJOB_FAILED
561+
and c.status == "True"
562+
):
560563
trainjob.status = c.type
561564
else:
562565
# The TrainJob running status is defined when all training node (e.g. Pods) are

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,31 @@
1919
It tests KubernetesBackend's behavior across job listing, resource creation etc
2020
"""
2121

22+
from dataclasses import asdict
2223
import datetime
2324
import multiprocessing
2425
import random
2526
import string
26-
import uuid
27-
from dataclasses import asdict
2827
from typing import Optional
2928
from unittest.mock import Mock, patch
29+
import uuid
3030

31-
import pytest
3231
from kubeflow_trainer_api import models
32+
import pytest
3333

34-
from kubeflow.trainer.constants import constants
35-
from kubeflow.trainer.types import types
36-
from kubeflow.trainer.utils import utils
3734
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
3835
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
39-
from kubeflow.trainer.test.common import TestCase
36+
from kubeflow.trainer.constants import constants
4037
from kubeflow.trainer.test.common import (
41-
SUCCESS,
42-
FAILED,
4338
DEFAULT_NAMESPACE,
44-
TIMEOUT,
39+
FAILED,
4540
RUNTIME,
41+
SUCCESS,
42+
TIMEOUT,
43+
TestCase,
4644
)
45+
from kubeflow.trainer.types import types
46+
from kubeflow.trainer.utils import utils
4747

4848
# In all tests runtime name is equal to the framework name.
4949
TORCH_RUNTIME = "torch"
@@ -788,7 +788,6 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
788788
},
789789
expected_error=ValueError,
790790
),
791-
792791
],
793792
)
794793
def test_train(kubernetes_backend, test_case):

kubeflow/trainer/backends/kubernetes/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import Optional
16+
1617
from kubernetes import client
1718
from pydantic import BaseModel
1819

0 commit comments

Comments
 (0)