Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Implement 'seed' for 'train_test_split' (take two) #678

Merged
merged 10 commits into from
Dec 11, 2024
25 changes: 19 additions & 6 deletions src/datachain/toolkit/split.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import random
from typing import Optional

from datachain import C, DataChain

RESOLUTION = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved


def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
def train_test_split(
dc: DataChain,
weights: list[float],
seed: Optional[int] = None,
) -> list[DataChain]:
"""
Splits a DataChain into multiple subsets based on the provided weights.

Expand All @@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
For example:
- `[0.7, 0.3]` corresponds to a 70/30 split;
- `[2, 1, 1]` corresponds to a 50/25/25 split.
seed (int, optional):
The seed for the random number generator. Defaults to None.

Returns:
list[DataChain]:
Expand Down Expand Up @@ -58,14 +69,16 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:

weights_normalized = [weight / sum(weights) for weight in weights]

resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
rand_col = C("sys.rand")
if seed is not None:
uniform_seed = random.Random(seed).randrange(1, RESOLUTION) # noqa: S311
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
rand_col = (rand_col % RESOLUTION) * uniform_seed # type: ignore[assignment]
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
rand_col = rand_col % RESOLUTION # type: ignore[assignment]
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved

return [
dc.filter(
C("sys__rand") % resolution
>= round(sum(weights_normalized[:index]) * resolution),
C("sys__rand") % resolution
< round(sum(weights_normalized[: index + 1]) * resolution),
rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
)
for index, _ in enumerate(weights_normalized)
]
22 changes: 12 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,18 +654,20 @@ def studio_datasets(requests_mock):

@pytest.fixture
def not_random_ds(test_session):
# `sys__rand` column is carefully crafted to ensure that `train_test_split` func
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
# will always return rows in `sys__id` order if no seed is provided.
return DataChain.from_records(
[
{"sys__id": 1, "sys__rand": 200000000, "fib": 0},
{"sys__id": 2, "sys__rand": 400000000, "fib": 1},
{"sys__id": 3, "sys__rand": 600000000, "fib": 1},
{"sys__id": 4, "sys__rand": 800000000, "fib": 2},
{"sys__id": 5, "sys__rand": 1000000000, "fib": 3},
{"sys__id": 6, "sys__rand": 1200000000, "fib": 5},
{"sys__id": 7, "sys__rand": 1400000000, "fib": 8},
{"sys__id": 8, "sys__rand": 1600000000, "fib": 13},
{"sys__id": 9, "sys__rand": 1800000000, "fib": 21},
{"sys__id": 10, "sys__rand": 2000000000, "fib": 34},
{"sys__id": 1, "sys__rand": 8025184816406567794, "fib": 0},
{"sys__id": 2, "sys__rand": 8264763963075908010, "fib": 1},
{"sys__id": 3, "sys__rand": 338514328625642097, "fib": 1},
{"sys__id": 4, "sys__rand": 508807229144041274, "fib": 2},
{"sys__id": 5, "sys__rand": 8730460072520445744, "fib": 3},
{"sys__id": 6, "sys__rand": 154987448000528066, "fib": 5},
{"sys__id": 7, "sys__rand": 6310705427500864020, "fib": 8},
{"sys__id": 8, "sys__rand": 2154127460471345108, "fib": 13},
{"sys__id": 9, "sys__rand": 2584481985215516118, "fib": 21},
{"sys__id": 10, "sys__rand": 5771949255753972681, "fib": 34},
],
session=test_session,
schema={"sys": Sys, "fib": int},
Expand Down
33 changes: 21 additions & 12 deletions tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,40 @@


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[0, [1, 1], [[3, 5], [1, 2, 4, 6, 7, 8, 9, 10]]],
[1, [1, 1], [[1, 3, 4, 6, 9, 10], [2, 5, 7, 8]]],
[1234567890, [1, 1], [[2, 4, 5, 7, 9], [1, 3, 6, 8, 10]]],
],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
res = train_test_split(not_random_ds, weights)
def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
res = train_test_split(not_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
assert list(dc.collect("sys.id")) == expected[i]


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[None, [1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[None, [4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[0, [1, 1], [[2, 5, 9, 10], [1, 3, 4, 6, 7, 8]]],
[1, [1, 1], [[1, 2, 3, 4, 5, 6, 8], [7, 9, 10]]],
[1234567890, [1, 1], [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]],
[0, [4, 1], [[1, 2, 4, 5, 7, 9, 10], [3, 6, 8]]],
[1, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[1234567890, [4, 1], [[1, 3, 5, 6, 7, 9, 10], [2, 4, 8]]],
],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
res = train_test_split(pseudo_random_ds, weights)
def test_train_test_split_random(pseudo_random_ds, seed, weights, expected):
res = train_test_split(pseudo_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
Expand Down
Loading