Skip to content

Commit

Permalink
Implement 'seed' for 'train_test_split' (take two) (#678)
Browse files Browse the repository at this point in the history
Co-authored-by: Helio Machado <0x2b3bfa0+git@googlemail.com>
  • Loading branch information
dreadatour and 0x2b3bfa0 authored Dec 11, 2024
1 parent 6ca3c98 commit 5db33a2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 28 deletions.
25 changes: 19 additions & 6 deletions src/datachain/toolkit/split.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import random
from typing import Optional

from datachain import C, DataChain

RESOLUTION = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.


def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
def train_test_split(
dc: DataChain,
weights: list[float],
seed: Optional[int] = None,
) -> list[DataChain]:
"""
Splits a DataChain into multiple subsets based on the provided weights.
Expand All @@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
For example:
- `[0.7, 0.3]` corresponds to a 70/30 split;
- `[2, 1, 1]` corresponds to a 50/25/25 split.
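seed (int, optional):
The seed for the random number generator. Defaults to None, in which case
the split is derived directly from each row's `sys.rand` value; passing a
seed produces a different, reproducible assignment of rows to subsets.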
Returns:
list[DataChain]:
Expand Down Expand Up @@ -58,14 +69,16 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:

weights_normalized = [weight / sum(weights) for weight in weights]

resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
rand_col = C("sys.rand")
if seed is not None:
uniform_seed = random.Random(seed).randrange(1, RESOLUTION) # noqa: S311
rand_col = (rand_col % RESOLUTION) * uniform_seed # type: ignore[assignment]
rand_col = rand_col % RESOLUTION # type: ignore[assignment]

return [
dc.filter(
C("sys__rand") % resolution
>= round(sum(weights_normalized[:index]) * resolution),
C("sys__rand") % resolution
< round(sum(weights_normalized[: index + 1]) * resolution),
rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
)
for index, _ in enumerate(weights_normalized)
]
22 changes: 12 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,18 +654,20 @@ def studio_datasets(requests_mock):

@pytest.fixture
def not_random_ds(test_session):
# The `sys__rand` column is carefully crafted to ensure that `train_test_split`
# always returns rows in `sys__id` order when no seed is provided.
return DataChain.from_records(
[
{"sys__id": 1, "sys__rand": 200000000, "fib": 0},
{"sys__id": 2, "sys__rand": 400000000, "fib": 1},
{"sys__id": 3, "sys__rand": 600000000, "fib": 1},
{"sys__id": 4, "sys__rand": 800000000, "fib": 2},
{"sys__id": 5, "sys__rand": 1000000000, "fib": 3},
{"sys__id": 6, "sys__rand": 1200000000, "fib": 5},
{"sys__id": 7, "sys__rand": 1400000000, "fib": 8},
{"sys__id": 8, "sys__rand": 1600000000, "fib": 13},
{"sys__id": 9, "sys__rand": 1800000000, "fib": 21},
{"sys__id": 10, "sys__rand": 2000000000, "fib": 34},
{"sys__id": 1, "sys__rand": 8025184816406567794, "fib": 0},
{"sys__id": 2, "sys__rand": 8264763963075908010, "fib": 1},
{"sys__id": 3, "sys__rand": 338514328625642097, "fib": 1},
{"sys__id": 4, "sys__rand": 508807229144041274, "fib": 2},
{"sys__id": 5, "sys__rand": 8730460072520445744, "fib": 3},
{"sys__id": 6, "sys__rand": 154987448000528066, "fib": 5},
{"sys__id": 7, "sys__rand": 6310705427500864020, "fib": 8},
{"sys__id": 8, "sys__rand": 2154127460471345108, "fib": 13},
{"sys__id": 9, "sys__rand": 2584481985215516118, "fib": 21},
{"sys__id": 10, "sys__rand": 5771949255753972681, "fib": 34},
],
session=test_session,
schema={"sys": Sys, "fib": int},
Expand Down
33 changes: 21 additions & 12 deletions tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,40 @@


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[0, [1, 1], [[3, 5], [1, 2, 4, 6, 7, 8, 9, 10]]],
[1, [1, 1], [[1, 3, 4, 6, 9, 10], [2, 5, 7, 8]]],
[1234567890, [1, 1], [[2, 4, 5, 7, 9], [1, 3, 6, 8, 10]]],
],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
res = train_test_split(not_random_ds, weights)
def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
res = train_test_split(not_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
assert list(dc.collect("sys.id")) == expected[i]


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[None, [1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[None, [4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[0, [1, 1], [[2, 5, 9, 10], [1, 3, 4, 6, 7, 8]]],
[1, [1, 1], [[1, 2, 3, 4, 5, 6, 8], [7, 9, 10]]],
[1234567890, [1, 1], [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]],
[0, [4, 1], [[1, 2, 4, 5, 7, 9, 10], [3, 6, 8]]],
[1, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[1234567890, [4, 1], [[1, 3, 5, 6, 7, 9, 10], [2, 4, 8]]],
],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
res = train_test_split(pseudo_random_ds, weights)
def test_train_test_split_random(pseudo_random_ds, seed, weights, expected):
res = train_test_split(pseudo_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
Expand Down

0 comments on commit 5db33a2

Please sign in to comment.