Skip to content

Commit

Permalink
Apply ruff pre-commit hook
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylion007 committed Aug 10, 2023
1 parent af98bdd commit 813ad66
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ default_language_version:
# a consistency with the official tfrecord preprocessing scripts
exclude: "^(streaming/text/convert/enwiki/)"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.282
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
Expand Down
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,25 @@ split_penalty_logical_operator = 300
# Use the Tab character for indentation.
use_tabs = false

[tool.ruff]
select = [
"C4",
# TODO port pydocstyle
# "D", # pydocstyle
"PERF",
]

ignore = [
"C408",
"PERF2",
"PERF4",
]
exclude = [
"build/**",
"docs/**",
"node_modules/**",
]

# Ignore directories
[tool.yapfignore]
ignore_patterns = [
Expand Down
2 changes: 1 addition & 1 deletion scripts/shuffle/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def wrap(func: Callable):

callers = list(map(wrap, get_shuffles))

text = ' '.join(map(lambda s: s.rjust(10), names))
text = ' '.join((s.rjust(10) for s in names))

print(f'{"power".rjust(5)} {"samples".rjust(14)} ' + text)
for mul_power in range(args.min_power * args.power_interval,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
'sphinx-tabs==3.4.1',
]

extra_deps['all'] = sorted(set(dep for deps in extra_deps.values() for dep in deps))
extra_deps['all'] = sorted({dep for deps in extra_deps.values() for dep in deps})

package_name = os.environ.get('MOSAIC_PACKAGE_NAME', 'mosaicml-streaming')

Expand Down
2 changes: 1 addition & 1 deletion streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def __init__(self,
if self.cache_limit:
if isinstance(self.cache_limit, str):
self.cache_limit = bytes_to_int(self.cache_limit)
min_cache_usage = sum(map(lambda stream: stream.get_index_size(), streams))
min_cache_usage = sum((stream.get_index_size() for stream in streams))
if self.cache_limit <= min_cache_usage:
raise ValueError(f'Minimum cache usage ({min_cache_usage} bytes) is larger than ' +
f'the cache limit ({self.cache_limit} bytes). Please raise ' +
Expand Down
4 changes: 2 additions & 2 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class NDArray(Encoding):
}

# Shape dtype -> integer <4.
_shape_dtype2int = dict((v, k) for k, v in _int2shape_dtype.items())
_shape_dtype2int = {v: k for k, v in _int2shape_dtype.items()}

# Integer <256 -> value dtype.
_int2value_dtype = {
Expand All @@ -140,7 +140,7 @@ class NDArray(Encoding):
}

# Value dtype -> integer <256.
_value_dtype2int = dict((v, k) for k, v in _int2value_dtype.items())
_value_dtype2int = {v: k for k, v in _int2value_dtype.items()}

@classmethod
def _get_static_size(cls, dtype: Optional[str], shape: Optional[Tuple[int]]) -> Optional[int]:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_eviction.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def validate(remote: str, local: str, dataset: StreamingDataset, keep_zip: bool,
if dataset.shards[0].compression:
# Local has raw, remote has zip.
assert ops(set(os.listdir(local)),
set(map(lambda f: f.replace('.zstd', ''), os.listdir(remote))))
{f.replace('.zstd', '') for f in os.listdir(remote)})
else:
# Local has raw, remote has raw.
assert ops(set(os.listdir(local)), set(os.listdir(remote)))
Expand Down Expand Up @@ -126,7 +126,7 @@ def cache_limit_too_low(remote: str, local: str, keep_zip: bool):


@pytest.mark.usefixtures('local_remote_dir')
@pytest.mark.parametrize('func', [f for f in funcs])
@pytest.mark.parametrize('func', list(funcs))
def test_eviction_nozip(local_remote_dir: Tuple[str, str], func: Any):
num_samples = 5_000
local, remote = local_remote_dir
Expand All @@ -148,7 +148,7 @@ def test_eviction_nozip(local_remote_dir: Tuple[str, str], func: Any):


@pytest.mark.usefixtures('local_remote_dir')
@pytest.mark.parametrize('func', [f for f in funcs])
@pytest.mark.parametrize('func', list(funcs))
def test_eviction_zip_nokeep(local_remote_dir: Tuple[str, str], func: Any):
num_samples = 5_000
local, remote = local_remote_dir
Expand All @@ -170,7 +170,7 @@ def test_eviction_zip_nokeep(local_remote_dir: Tuple[str, str], func: Any):


@pytest.mark.usefixtures('local_remote_dir')
@pytest.mark.parametrize('func', [f for f in funcs])
@pytest.mark.parametrize('func', list(funcs))
def test_eviction_zip_keep(local_remote_dir: Tuple[str, str], func: Any):
num_samples = 5_000
local, remote = local_remote_dir
Expand Down

0 comments on commit 813ad66

Please sign in to comment.