Use ruff for formatting (#6434)
* Use `ruff` for formatting

* Update quality dependencies and lint setup.py

* Update pre-commit-config

* Small fix
mariosasko authored Nov 21, 2023
1 parent c65315e commit 1a1e741
Showing 12 changed files with 48 additions and 47 deletions.
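
At a glance, the commit replaces the black + ruff pair with ruff alone for both jobs. A summary sketch of the before/after quality commands, taken from the diffs below:

    # before: black formats, ruff lints
    black --check tests src benchmarks metrics
    ruff tests src benchmarks metrics

    # after: ruff does both
    ruff check tests src benchmarks metrics utils setup.py # linter
    ruff format --check tests src benchmarks metrics utils setup.py # formatter
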
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml

@@ -28,8 +28,8 @@ jobs:
           pip install .[quality]
       - name: Check quality
         run: |
-          black --check tests src benchmarks metrics
-          ruff tests src benchmarks metrics
+          ruff check tests src benchmarks metrics utils setup.py # linter
+          ruff format --check tests src benchmarks metrics utils setup.py # formatter

   test:
     needs: check_code_quality

20 changes: 7 additions & 13 deletions .pre-commit-config.yaml

@@ -1,15 +1,9 @@
 repos:
-  - repo: https://github.com/psf/black
-    rev: 23.1.0
+  - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage
+    rev: 'v0.1.5'
     hooks:
-      - id: black
-        language_version: python3
-        types: [python]
-        stages: [commit]
-        args: ["--config", "pyproject.toml", "tests", "src", "benchmarks", "metrics"]
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.255'
-    hooks:
-      - id: ruff
-        stages: [commit]
-        args: [ "--config", "pyproject.toml", "tests", "src", "benchmarks", "metrics", "--fix"]
+      # Run the linter.
+      - id: ruff
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format

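The new config swaps the black hook and the old ruff hook for the two hooks that ruff-pre-commit ships: `ruff` (linting, with `--fix`) and `ruff-format`. A usage sketch with the standard pre-commit CLI, which is not part of this diff:

    pip install pre-commit
    pre-commit install          # register the git hook once per clone
    pre-commit run --all-files  # run both hooks across the whole repo
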
8 changes: 4 additions & 4 deletions Makefile

@@ -5,14 +5,14 @@ check_dirs := tests src benchmarks metrics utils

 # Check that source code meets quality standards

 quality:
-	black --check $(check_dirs) setup.py
-	ruff $(check_dirs) setup.py
+	ruff check $(check_dirs) setup.py # linter
+	ruff format --check $(check_dirs) setup.py # formatter

 # Format source code automatically

 style:
-	black tests src benchmarks metrics setup.py
-	ruff $(check_dirs) setup.py --fix
+	ruff check --fix $(check_dirs) setup.py # linter
+	ruff format $(check_dirs) setup.py # formatter

 # Run tests for the library

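Day to day, the two Makefile targets wrap the same pair of ruff invocations; a usage sketch:

    make quality   # read-only: ruff check + ruff format --check
    make style     # rewrites files: ruff check --fix + ruff format
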
2 changes: 1 addition & 1 deletion setup.py

@@ -216,7 +216,7 @@
 TESTS_REQUIRE.extend(VISION_REQUIRE)
 TESTS_REQUIRE.extend(AUDIO_REQUIRE)

-QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241", "pyyaml>=5.3.1"]
+QUALITY_REQUIRE = ["ruff>=0.1.5"]

 DOCS_REQUIRE = [
     # Might need to add doc-builder and some specific deps in the future

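With black gone, the quality extra reduces to ruff alone, so reproducing the CI check locally only needs (a usage sketch mirroring the CI step above):

    pip install ".[quality]"   # installs ruff>=0.1.5
    make quality
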
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py

@@ -3101,7 +3101,8 @@ def load_processed_shard_from_cache(shard_kwargs):
         else:

             def format_cache_file_name(
-                cache_file_name: Optional[str], rank: Union[int, Literal["*"]]  # noqa: F722
+                cache_file_name: Optional[str],
+                rank: Union[int, Literal["*"]],  # noqa: F722
             ) -> Optional[str]:
                 if not cache_file_name:
                     return cache_file_name

10 changes: 6 additions & 4 deletions src/datasets/iterable_dataset.py

@@ -127,8 +127,9 @@ def _convert_to_arrow(
     Drop the last batch if it is smaller than `batch_size`.
     """
     if batch_size is None or batch_size <= 0:
-        yield "all", pa.Table.from_pylist(
-            cast_to_python_objects([example for _, example in iterable], only_1d_for_numpy=True)
+        yield (
+            "all",
+            pa.Table.from_pylist(cast_to_python_objects([example for _, example in iterable], only_1d_for_numpy=True)),
         )
         return
     iterator = iter(iterable)

@@ -1112,8 +1113,9 @@ def __iter__(self):
         # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None.
         # This is done with `_apply_feature_types_on_example`.
         for key, example in self.ex_iterable:
-            yield key, _apply_feature_types_on_example(
-                example, self.features, token_per_repo_id=self.token_per_repo_id
+            yield (
+                key,
+                _apply_feature_types_on_example(example, self.features, token_per_repo_id=self.token_per_repo_id),
             )

     def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]:

8 changes: 2 additions & 6 deletions src/datasets/load.py

@@ -1493,9 +1493,7 @@ def dataset_module_factory(
                 download_config=download_config,
                 download_mode=download_mode,
             ).get_module()
-        except (
-            Exception
-        ) as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
+        except Exception as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
             try:
                 return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
             except Exception:  # noqa if it's not in the cache, then it doesn't exist.

@@ -1598,9 +1596,7 @@ def metric_module_factory(
                 download_mode=download_mode,
                 dynamic_modules_path=dynamic_modules_path,
             ).get_module()
-        except (
-            Exception
-        ) as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
+        except Exception as e1:  # noqa all the attempts failed, before raising the error we should check if the module is already cached.
             try:
                 return CachedMetricModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
             except Exception:  # noqa if it's not in the cache, then it doesn't exist.

30 changes: 18 additions & 12 deletions src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

@@ -323,12 +323,15 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, add_labels):
                         sample_label = {"label": os.path.basename(os.path.dirname(original_file))}
                     else:
                         sample_label = {}
-                    yield file_idx, {
-                        **sample_empty_metadata,
-                        self.BASE_COLUMN_NAME: downloaded_file_or_dir,
-                        **sample_metadata,
-                        **sample_label,
-                    }
+                    yield (
+                        file_idx,
+                        {
+                            **sample_empty_metadata,
+                            self.BASE_COLUMN_NAME: downloaded_file_or_dir,
+                            **sample_metadata,
+                            **sample_label,
+                        },
+                    )
                     file_idx += 1
                 else:
                     for downloaded_dir_file in downloaded_file_or_dir:

@@ -391,10 +394,13 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, add_labels):
                             sample_label = {"label": os.path.basename(os.path.dirname(downloaded_dir_file))}
                         else:
                             sample_label = {}
-                        yield file_idx, {
-                            **sample_empty_metadata,
-                            self.BASE_COLUMN_NAME: downloaded_dir_file,
-                            **sample_metadata,
-                            **sample_label,
-                        }
+                        yield (
+                            file_idx,
+                            {
+                                **sample_empty_metadata,
+                                self.BASE_COLUMN_NAME: downloaded_dir_file,
+                                **sample_metadata,
+                                **sample_label,
+                            },
+                        )
                         file_idx += 1

3 changes: 3 additions & 0 deletions src/datasets/splits.py

@@ -111,6 +111,7 @@ class SplitBase(metaclass=abc.ABCMeta):
     to define which files to read and how to skip examples within file.
     """
+
     # pylint: enable=line-too-long

     @abc.abstractmethod

@@ -265,6 +266,7 @@ class PercentSlice(metaclass=PercentSliceMeta):
     [guide on splits](../loading#slice-splits)
     for more information.
     """
+
     # pylint: enable=line-too-long
     pass

@@ -438,6 +440,7 @@ class Split:
     ...     )
     ```
     """
+
     # pylint: enable=line-too-long
     TRAIN = NamedSplit("train")
     TEST = NamedSplit("test")

2 changes: 1 addition & 1 deletion src/datasets/utils/patching.py

@@ -63,7 +63,7 @@ def __enter__(self):
                 # We don't check for the name of the global, but rather if its value *is* "os" or "os.path".
                 # This allows to patch renamed modules like "from os import path as ospath".
                 if obj_attr is submodule or (
-                    (isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule)
+                    isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule
                 ):
                     self.original[attr] = obj_attr
                     # patch at top level

4 changes: 1 addition & 3 deletions tests/test_arrow_dataset.py

@@ -3066,9 +3066,7 @@ def test_concatenate_mixed_memory_and_disk(self):
                 cache_file_name=os.path.join(tmp_dir, "d1.arrow")
             ) as dset1, Dataset.from_dict(data2, info=info2).map(
                 cache_file_name=os.path.join(tmp_dir, "d2.arrow")
-            ) as dset2, Dataset.from_dict(
-                data3
-            ) as dset3:
+            ) as dset2, Dataset.from_dict(data3) as dset3:
                 with concatenate_datasets([dset1, dset2, dset3]) as concatenated_dset:
                     self.assertEqual(len(concatenated_dset), len(dset1) + len(dset2) + len(dset3))
                     self.assertListEqual(concatenated_dset["id"], dset1["id"] + dset2["id"] + dset3["id"])

1 change: 1 addition & 0 deletions tests/test_readme_util.py

@@ -11,6 +11,7 @@
 # @pytest.fixture
 # def example_yaml_structure():

+
 example_yaml_structure = yaml.safe_load(
     """\
 name: ""
