Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Implement 'seed' for 'train_test_split' + simplify split logic #657

Closed
wants to merge 14 commits into from
Closed
38 changes: 29 additions & 9 deletions src/datachain/toolkit/split.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from datachain import C, DataChain
from typing import Optional

from datachain import DataChain
from datachain.func import bit_and, bit_xor, int_hash_64

0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved
def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
MAX_SIGNED_INT64 = (1 << 63) - 1 # Maximum positive value for a 64-bit signed integer.


def train_test_split(
dc: DataChain,
weights: list[float],
seed: Optional[int] = None,
) -> list[DataChain]:
"""
Splits a DataChain into multiple subsets based on the provided weights.

Expand All @@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
For example:
- `[0.7, 0.3]` corresponds to a 70/30 split;
- `[2, 1, 1]` corresponds to a 50/25/25 split.
seed (int, optional):
The seed for the random number generator. Defaults to None.

Returns:
list[DataChain]:
Expand Down Expand Up @@ -56,16 +67,25 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
if any(weight < 0 for weight in weights):
raise ValueError("Weights should be non-negative")

weights_normalized = [weight / sum(weights) for weight in weights]
def scale(multiplier: float) -> int:
multiplier = min(max(0.0, multiplier), 1.0)
numerator, denominator = float(multiplier).as_integer_ratio()
return MAX_SIGNED_INT64 * numerator // denominator

resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
weights_normalized = [weight / sum(weights) for weight in weights]
limits = [
[
scale(sum(weights_normalized[:index])),
scale(sum(weights_normalized[: index + 1])),
]
for index, _ in enumerate(weights_normalized)
]

rand_col = "sys.rand" if seed is None else int_hash_64(bit_xor("sys.rand", seed))
return [
dc.filter(
C("sys__rand") % resolution
>= round(sum(weights_normalized[:index]) * resolution),
C("sys__rand") % resolution
< round(sum(weights_normalized[: index + 1]) * resolution),
bit_and(rand_col, MAX_SIGNED_INT64) >= min_,
bit_and(rand_col, MAX_SIGNED_INT64) < max_,
)
for index, _ in enumerate(weights_normalized)
for min_, max_ in limits
]
20 changes: 10 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,16 +656,16 @@ def studio_datasets(requests_mock):
def not_random_ds(test_session):
return DataChain.from_records(
[
{"sys__id": 1, "sys__rand": 200000000, "fib": 0},
{"sys__id": 2, "sys__rand": 400000000, "fib": 1},
{"sys__id": 3, "sys__rand": 600000000, "fib": 1},
{"sys__id": 4, "sys__rand": 800000000, "fib": 2},
{"sys__id": 5, "sys__rand": 1000000000, "fib": 3},
{"sys__id": 6, "sys__rand": 1200000000, "fib": 5},
{"sys__id": 7, "sys__rand": 1400000000, "fib": 8},
{"sys__id": 8, "sys__rand": 1600000000, "fib": 13},
{"sys__id": 9, "sys__rand": 1800000000, "fib": 21},
{"sys__id": 10, "sys__rand": 2000000000, "fib": 34},
{"sys__id": 1, "sys__rand": 461168601842738816, "fib": 0},
{"sys__id": 2, "sys__rand": 1383505805528216320, "fib": 1},
{"sys__id": 3, "sys__rand": 2305843009213693952, "fib": 1},
{"sys__id": 4, "sys__rand": 3228180212899171328, "fib": 2},
{"sys__id": 5, "sys__rand": 4150517416584649216, "fib": 3},
{"sys__id": 6, "sys__rand": 5072854620270127104, "fib": 5},
{"sys__id": 7, "sys__rand": 5995191823955604480, "fib": 8},
{"sys__id": 8, "sys__rand": 6917529027641081856, "fib": 13},
{"sys__id": 9, "sys__rand": 7839866231326559232, "fib": 21},
{"sys__id": 10, "sys__rand": 8762203435012036608, "fib": 34},
],
session=test_session,
schema={"sys": Sys, "fib": int},
Expand Down
33 changes: 21 additions & 12 deletions tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,40 @@


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[0, [1, 1], [[1, 2, 3, 7], [4, 5, 6, 8, 9, 10]]],
[1, [1, 1], [[3, 4, 5, 7, 9, 10], [1, 2, 6, 8]]],
[1234567890, [1, 1], [[5, 8, 10], [1, 2, 3, 4, 6, 7, 9]]],
],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
res = train_test_split(not_random_ds, weights)
def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
res = train_test_split(not_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
assert list(dc.collect("sys.id")) == expected[i]


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[None, [1, 1], [[1, 2, 4, 6, 8], [3, 5, 7, 9, 10]]],
[None, [4, 1], [[1, 2, 4, 6, 8, 10], [3, 5, 7, 9]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 4, 6, 8, 10], [5, 9], [3, 7]]],
[0, [1, 1], [[1, 6, 8, 9], [2, 3, 4, 5, 7, 10]]],
[1, [1, 1], [[1, 3, 5, 6, 7, 9], [2, 4, 8, 10]]],
[1234567890, [1, 1], [[2, 3, 7, 8, 10], [1, 4, 5, 6, 9]]],
[0, [4, 1], [[1, 6, 7, 8, 9, 10], [2, 3, 4, 5]]],
[1, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 9, 10], [8]]],
[1234567890, [4, 1], [[2, 3, 4, 6, 7, 8, 10], [1, 5, 9]]],
],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
res = train_test_split(pseudo_random_ds, weights)
def test_train_test_split_random(pseudo_random_ds, seed, weights, expected):
res = train_test_split(pseudo_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
Expand Down
Loading