Skip to content

Commit

Permalink
Code review update
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Dec 10, 2024
1 parent 1d3827e commit bc1ea48
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/datachain/toolkit/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def train_test_split(
if seed is not None:
uniform_seed = random.Random(seed).randrange(1, RESOLUTION + 1) # noqa: S311
rand_col = (rand_col % (RESOLUTION + 1)) * uniform_seed # type: ignore[assignment]
rand_col = rand_col % (RESOLUTION + 1) # type: ignore[assignment]
rand_col = rand_col % RESOLUTION # type: ignore[assignment]

return [
dc.filter(
rand_col >= round(sum(weights_normalized[:index]) * RESOLUTION),
rand_col < round(sum(weights_normalized[: index + 1]) * RESOLUTION),
rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION + 1)),
rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION + 1)),
)
for index, _ in enumerate(weights_normalized)
]
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,16 +658,16 @@ def not_random_ds(test_session):
# will always return columns in the `sys__id` order if no seed is provided.
return DataChain.from_records(
[
{"sys__id": 1, "sys__rand": 3398273711618747276, "fib": 0},
{"sys__id": 1, "sys__rand": 8025184816406567794, "fib": 0},
{"sys__id": 2, "sys__rand": 8264763963075908010, "fib": 1},
{"sys__id": 3, "sys__rand": 2466105069438384471, "fib": 1},
{"sys__id": 3, "sys__rand": 338514328625642097, "fib": 1},
{"sys__id": 4, "sys__rand": 508807229144041274, "fib": 2},
{"sys__id": 5, "sys__rand": 8730460072520445744, "fib": 3},
{"sys__id": 6, "sys__rand": 154987448000528066, "fib": 5},
{"sys__id": 7, "sys__rand": 4042600467130455702, "fib": 8},
{"sys__id": 8, "sys__rand": 7213364538346925057, "fib": 13},
{"sys__id": 9, "sys__rand": 2695061131372526602, "fib": 21},
{"sys__id": 10, "sys__rand": 2779685447872158540, "fib": 34},
{"sys__id": 7, "sys__rand": 6310705427500864020, "fib": 8},
{"sys__id": 8, "sys__rand": 2154127460471345108, "fib": 13},
{"sys__id": 9, "sys__rand": 2584481985215516118, "fib": 21},
{"sys__id": 10, "sys__rand": 5771949255753972681, "fib": 34},
],
session=test_session,
schema={"sys": Sys, "fib": int},
Expand Down
22 changes: 11 additions & 11 deletions tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
[None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[0, [1, 1], [[5, 6, 8], [1, 2, 3, 4, 7, 9, 10]]],
[1, [1, 1], [[1, 2, 3, 4, 6, 7, 8, 10], [5, 9]]],
[1234567890, [1, 1], [[2, 3, 5], [1, 4, 6, 7, 8, 9, 10]]],
[0, [1, 1], [[1, 3, 5, 9, 10], [2, 4, 6, 7, 8]]],
[1, [1, 1], [[1, 4, 6, 7, 8, 9], [2, 3, 5, 10]]],
[1234567890, [1, 1], [[2, 3, 4, 5, 6, 7, 8, 9, 10], [1]]],
],
)
def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
Expand All @@ -25,15 +25,15 @@ def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
@pytest.mark.parametrize(
"seed,weights,expected",
[
[None, [1, 1], [[4, 5, 7, 8], [1, 2, 3, 6, 9, 10]]],
[None, [4, 1], [[1, 2, 4, 5, 7, 8, 10], [3, 6, 9]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 4, 5, 7, 8, 10], [3, 6, 9], []]],
[0, [1, 1], [[4, 7, 8, 10], [1, 2, 3, 5, 6, 9]]],
[1, [1, 1], [[3, 4, 6, 7, 10], [1, 2, 5, 8, 9]]],
[1234567890, [1, 1], [[1, 2, 3, 4, 5, 6, 8], [7, 9, 10]]],
[0, [4, 1], [[1, 3, 4, 5, 7, 8, 10], [2, 6, 9]]],
[None, [1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[None, [4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[0, [1, 1], [[1, 3, 4, 8, 10], [2, 5, 6, 7, 9]]],
[1, [1, 1], [[4, 7], [1, 2, 3, 5, 6, 8, 9, 10]]],
[1234567890, [1, 1], [[6, 7, 8, 10], [1, 2, 3, 4, 5, 9]]],
[0, [4, 1], [[1, 2, 3, 4, 6, 7, 8, 9, 10], [5]]],
[1, [4, 1], [[2, 3, 4, 5, 6, 7, 8, 9, 10], [1]]],
[1234567890, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], []]],
[1234567890, [4, 1], [[1, 2, 4, 5, 6, 7, 8, 9, 10], [3]]],
],
)
def test_train_test_split_random(pseudo_random_ds, seed, weights, expected):
Expand Down

0 comments on commit bc1ea48

Please sign in to comment.