diff --git a/src/datachain/toolkit/split.py b/src/datachain/toolkit/split.py index 426c24950..6699005d2 100644 --- a/src/datachain/toolkit/split.py +++ b/src/datachain/toolkit/split.py @@ -1,7 +1,16 @@ +import random +from typing import Optional + from datachain import C, DataChain +RESOLUTION = 2**31 - 1 # Maximum positive value for a 32-bit signed integer. + -def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]: +def train_test_split( + dc: DataChain, + weights: list[float], + seed: Optional[int] = None, +) -> list[DataChain]: """ Splits a DataChain into multiple subsets based on the provided weights. @@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]: For example: - `[0.7, 0.3]` corresponds to a 70/30 split; - `[2, 1, 1]` corresponds to a 50/25/25 split. + seed (int, optional): + The seed for the random number generator. Defaults to None. Returns: list[DataChain]: @@ -58,14 +69,16 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]: weights_normalized = [weight / sum(weights) for weight in weights] - resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer. + rand_col = C("sys.rand") + if seed is not None: + uniform_seed = random.Random(seed).randrange(1, RESOLUTION) # noqa: S311 + rand_col = (rand_col % RESOLUTION) * uniform_seed # type: ignore[assignment] + rand_col = rand_col % RESOLUTION # type: ignore[assignment] return [ dc.filter( - C("sys__rand") % resolution - >= round(sum(weights_normalized[:index]) * resolution), - C("sys__rand") % resolution - < round(sum(weights_normalized[: index + 1]) * resolution), + rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)), + rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)), ) for index, _ in enumerate(weights_normalized) ] diff --git a/tests/conftest.py b/tests/conftest.py index 45e919178..6a426d26a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -654,18 +654,20 @@ def studio_datasets(requests_mock): @pytest.fixture def not_random_ds(test_session): + # `sys__rand` column is carefully crafted to ensure that `train_test_split` func + # will always return columns in the `sys__id` order if no seed is provided. return DataChain.from_records( [ - {"sys__id": 1, "sys__rand": 200000000, "fib": 0}, - {"sys__id": 2, "sys__rand": 400000000, "fib": 1}, - {"sys__id": 3, "sys__rand": 600000000, "fib": 1}, - {"sys__id": 4, "sys__rand": 800000000, "fib": 2}, - {"sys__id": 5, "sys__rand": 1000000000, "fib": 3}, - {"sys__id": 6, "sys__rand": 1200000000, "fib": 5}, - {"sys__id": 7, "sys__rand": 1400000000, "fib": 8}, - {"sys__id": 8, "sys__rand": 1600000000, "fib": 13}, - {"sys__id": 9, "sys__rand": 1800000000, "fib": 21}, - {"sys__id": 10, "sys__rand": 2000000000, "fib": 34}, + {"sys__id": 1, "sys__rand": 8025184816406567794, "fib": 0}, + {"sys__id": 2, "sys__rand": 8264763963075908010, "fib": 1}, + {"sys__id": 3, "sys__rand": 338514328625642097, "fib": 1}, + {"sys__id": 4, "sys__rand": 508807229144041274, "fib": 2}, + {"sys__id": 5, "sys__rand": 8730460072520445744, "fib": 3}, + {"sys__id": 6, "sys__rand": 154987448000528066, "fib": 5}, + {"sys__id": 7, "sys__rand": 6310705427500864020, "fib": 8}, + {"sys__id": 8, "sys__rand": 2154127460471345108, "fib": 13}, + {"sys__id": 9, "sys__rand": 2584481985215516118, "fib": 21}, + {"sys__id": 10, "sys__rand": 5771949255753972681, "fib": 34}, ], session=test_session, schema={"sys": Sys, "fib": int}, diff --git a/tests/func/test_toolkit.py b/tests/func/test_toolkit.py index 673475eb7..06d45afdd 100644 --- a/tests/func/test_toolkit.py +++ b/tests/func/test_toolkit.py @@ -4,15 +4,18 @@ @pytest.mark.parametrize( - "weights,expected", + "seed,weights,expected", [ - [[1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]], - [[4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]], - [[0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]], + [None, [1, 1], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]], + [None, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]], + [None, [0.7, 0.2, 0.1], [[1, 2, 3, 4, 5, 6, 7], [8, 9], [10]]], + [0, [1, 1], [[3, 5], [1, 2, 4, 6, 7, 8, 9, 10]]], + [1, [1, 1], [[1, 3, 4, 6, 9, 10], [2, 5, 7, 8]]], + [1234567890, [1, 1], [[2, 4, 5, 7, 9], [1, 3, 6, 8, 10]]], ], ) -def test_train_test_split_not_random(not_random_ds, weights, expected): - res = train_test_split(not_random_ds, weights) +def test_train_test_split_not_random(not_random_ds, seed, weights, expected): + res = train_test_split(not_random_ds, weights, seed=seed) assert len(res) == len(expected) for i, dc in enumerate(res): @@ -20,15 +23,21 @@ def test_train_test_split_not_random(not_random_ds, weights, expected): @pytest.mark.parametrize( - "weights,expected", + "seed,weights,expected", [ - [[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]], - [[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]], - [[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]], + [None, [1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]], + [None, [4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]], + [None, [0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]], + [0, [1, 1], [[2, 5, 9, 10], [1, 3, 4, 6, 7, 8]]], + [1, [1, 1], [[1, 2, 3, 4, 5, 6, 8], [7, 9, 10]]], + [1234567890, [1, 1], [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]], + [0, [4, 1], [[1, 2, 4, 5, 7, 9, 10], [3, 6, 8]]], + [1, [4, 1], [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10]]], + [1234567890, [4, 1], [[1, 3, 5, 6, 7, 9, 10], [2, 4, 8]]], ], ) -def test_train_test_split_random(pseudo_random_ds, weights, expected): - res = train_test_split(pseudo_random_ds, weights) +def test_train_test_split_random(pseudo_random_ds, seed, weights, expected): + res = train_test_split(pseudo_random_ds, weights, seed=seed) assert len(res) == len(expected) for i, dc in enumerate(res):