diff --git a/tests/integration_tests/test_ray.py b/tests/integration_tests/test_ray.py index 5e8efe5e6b3..0d87660e25f 100644 --- a/tests/integration_tests/test_ray.py +++ b/tests/integration_tests/test_ray.py @@ -26,7 +26,7 @@ from ludwig.api import LudwigModel from ludwig.backend import create_ray_backend, LOCAL_BACKEND from ludwig.backend.ray import get_trainer_kwargs, RayBackend -from ludwig.constants import BALANCE_PERCENTAGE_TOLERANCE, NAME, TRAINER +from ludwig.constants import BALANCE_PERCENTAGE_TOLERANCE, NAME, PREPROCESSING, TRAINER from ludwig.data.dataframe.dask import DaskEngine from ludwig.data.preprocessing import balance_data from ludwig.utils.data_utils import read_parquet @@ -148,7 +148,9 @@ def run_test_parquet( num_cpus=2, num_gpus=None, df_engine=None, + preprocessing=None, ): + preprocessing = preprocessing or {} with ray_start(num_cpus=num_cpus, num_gpus=num_gpus): config = { "input_features": input_features, @@ -156,6 +158,8 @@ def run_test_parquet( "combiner": {"type": "concat", "output_size": 14}, TRAINER: {"epochs": 2, "batch_size": 8}, } + if preprocessing: + config[PREPROCESSING] = preprocessing backend_config = {**RAY_BACKEND_CONFIG} if df_engine: @@ -261,6 +265,21 @@ def test_ray_split(): ) +@pytest.mark.distributed +def test_ray_stratify(): + input_features = [ + number_feature(normalization="zscore"), + set_feature(), + binary_feature(), + ] + output_features = [binary_feature()] + run_test_parquet( + input_features, + output_features, + preprocessing={"stratify": output_features[0][NAME]}, + ) + + @pytest.mark.distributed def test_ray_timeseries(): input_features = [timeseries_feature()]