Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 'seed' for 'train_test_split' + simplify split logic #657

Closed
wants to merge 14 commits into from
Closed
41 changes: 32 additions & 9 deletions src/datachain/toolkit/split.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from datachain import C, DataChain
from typing import Optional

from datachain import DataChain
from datachain.func import bit_and, bit_xor, int_hash_64
0x2b3bfa0 marked this conversation as resolved.
Show resolved Hide resolved

def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
MAX_SIGNED_INT64 = (1 << 63) - 1 # Maximum positive value for a 64-bit signed integer.


def train_test_split(
dc: DataChain,
weights: list[float],
seed: Optional[int] = None,
) -> list[DataChain]:
"""
Splits a DataChain into multiple subsets based on the provided weights.

Expand All @@ -18,6 +27,8 @@
For example:
- `[0.7, 0.3]` corresponds to a 70/30 split;
- `[2, 1, 1]` corresponds to a 50/25/25 split.
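seed (int, optional):
Value combined with each row's `sys.rand` (XOR, then `int_hash_64`) so that the same seed reproduces the same split. Defaults to None, in which case `sys.rand` is used unchanged.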

Returns:
list[DataChain]:
Expand Down Expand Up @@ -56,16 +67,28 @@
if any(weight < 0 for weight in weights):
raise ValueError("Weights should be non-negative")

weights_normalized = [weight / sum(weights) for weight in weights]
def scale(multiplier: float) -> int:
    """Convert a fraction into a threshold on the signed 64-bit range.

    The fraction is clamped to ``[0, 1]`` and then multiplied by
    ``MAX_SIGNED_INT64`` using integer arithmetic on the float's exact
    numerator/denominator, so the multiplication itself adds no rounding.

    Args:
        multiplier: Fraction of the range; values outside ``[0, 1]`` are
            clamped to the nearest bound.

    Returns:
        int: ``MAX_SIGNED_INT64 * fraction``, floored to an integer.
    """
    if multiplier < 0:
        # NOTE(review): the Codecov annotation on this PR reported this
        # assignment (line 72 of split.py) as not covered by tests.
        fraction = 0.0
    elif multiplier > 1:
        fraction = 1.0
    else:
        fraction = float(multiplier)

    top, bottom = fraction.as_integer_ratio()
    return MAX_SIGNED_INT64 * top // bottom

resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
weights_normalized = [weight / sum(weights) for weight in weights]
limits = [
[
scale(sum(weights_normalized[:index])),
scale(sum(weights_normalized[: index + 1])),
]
for index, _ in enumerate(weights_normalized)
]

rand_col = "sys.rand" if seed is None else int_hash_64(bit_xor("sys.rand", seed))
return [
dc.filter(
C("sys__rand") % resolution
>= round(sum(weights_normalized[:index]) * resolution),
C("sys__rand") % resolution
< round(sum(weights_normalized[: index + 1]) * resolution),
bit_and(rand_col, MAX_SIGNED_INT64) >= min_,
bit_and(rand_col, MAX_SIGNED_INT64) < max_,
)
for index, _ in enumerate(weights_normalized)
for min_, max_ in limits
]
20 changes: 10 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,16 +656,16 @@ def studio_datasets(requests_mock):
def not_random_ds(test_session):
return DataChain.from_records(
[
{"sys__id": 1, "sys__rand": 200000000, "fib": 0},
{"sys__id": 2, "sys__rand": 400000000, "fib": 1},
{"sys__id": 3, "sys__rand": 600000000, "fib": 1},
{"sys__id": 4, "sys__rand": 800000000, "fib": 2},
{"sys__id": 5, "sys__rand": 1000000000, "fib": 3},
{"sys__id": 6, "sys__rand": 1200000000, "fib": 5},
{"sys__id": 7, "sys__rand": 1400000000, "fib": 8},
{"sys__id": 8, "sys__rand": 1600000000, "fib": 13},
{"sys__id": 9, "sys__rand": 1800000000, "fib": 21},
{"sys__id": 10, "sys__rand": 2000000000, "fib": 34},
{"sys__id": 1, "sys__rand": 461168601842738816, "fib": 0},
{"sys__id": 2, "sys__rand": 1383505805528216320, "fib": 1},
{"sys__id": 3, "sys__rand": 2305843009213693952, "fib": 1},
{"sys__id": 4, "sys__rand": 3228180212899171328, "fib": 2},
{"sys__id": 5, "sys__rand": 4150517416584649216, "fib": 3},
{"sys__id": 6, "sys__rand": 5072854620270127104, "fib": 5},
{"sys__id": 7, "sys__rand": 5995191823955604480, "fib": 8},
{"sys__id": 8, "sys__rand": 6917529027641081856, "fib": 13},
{"sys__id": 9, "sys__rand": 7839866231326559232, "fib": 21},
{"sys__id": 10, "sys__rand": 8762203435012036608, "fib": 34},
],
session=test_session,
schema={"sys": Sys, "fib": int},
Expand Down
33 changes: 21 additions & 12 deletions tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,40 @@


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]],
[None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]],
[0, [1, 1], [[1, 2, 3, 7], [4, 5, 6, 8, 9, 10]]],
[1, [1, 1], [[3, 4, 5, 7, 9, 10], [1, 2, 6, 8]]],
[1234567890, [1, 1], [[5, 8, 10], [1, 2, 3, 4, 6, 7, 9]]],
],
)
def test_train_test_split_not_random(not_random_ds, weights, expected):
res = train_test_split(not_random_ds, weights)
def test_train_test_split_not_random(not_random_ds, seed, weights, expected):
    """Check the split of the fixed-value fixture against the expected ids.

    Splits ``not_random_ds`` with the given ``weights`` and ``seed`` and
    verifies that the number of resulting chains and the ``sys.id`` values
    collected from each chain match ``expected``.
    """
    splits = train_test_split(not_random_ds, weights, seed=seed)
    assert len(splits) == len(expected)

    # Lengths are equal at this point, so pairing visits every split once.
    for part, expected_ids in zip(splits, expected):
        assert list(part.collect("sys.id")) == expected_ids


@pytest.mark.parametrize(
"weights,expected",
"seed,weights,expected",
[
[[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]],
[[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]],
[[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]],
[None, [1, 1], [[1, 2, 4, 6, 8], [3, 5, 7, 9, 10]]],
[None, [4, 1], [[1, 2, 4, 6, 8, 10], [3, 5, 7, 9]]],
[None, [0.7, 0.2, 0.1], [[1, 2, 4, 6, 8, 10], [5, 9], [3, 7]]],
[0, [1, 1], [[1, 6, 8, 9], [2, 3, 4, 5, 7, 10]]],
[1, [1, 1], [[1, 3, 5, 6, 7, 9], [2, 4, 8, 10]]],
[1234567890, [1, 1], [[2, 3, 7, 8, 10], [1, 4, 5, 6, 9]]],
[0, [4, 1], [[1, 6, 7, 8, 9, 10], [2, 3, 4, 5]]],
[1, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 9, 10], [8]]],
[1234567890, [4, 1], [[2, 3, 4, 6, 7, 8, 10], [1, 5, 9]]],
],
)
def test_train_test_split_random(pseudo_random_ds, weights, expected):
res = train_test_split(pseudo_random_ds, weights)
def test_train_test_split_random(pseudo_random_ds, seed, weights, expected):
res = train_test_split(pseudo_random_ds, weights, seed=seed)
assert len(res) == len(expected)

for i, dc in enumerate(res):
Expand Down
Loading