Skip to content

Commit

Permalink
Added test for stratified sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Apr 6, 2022
1 parent 3da5cb6 commit 2d8c7b8
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/integration_tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ludwig.api import LudwigModel
from ludwig.backend import create_ray_backend, LOCAL_BACKEND
from ludwig.backend.ray import get_trainer_kwargs, RayBackend
from ludwig.constants import BALANCE_PERCENTAGE_TOLERANCE, NAME, TRAINER
from ludwig.constants import BALANCE_PERCENTAGE_TOLERANCE, NAME, PREPROCESSING, TRAINER
from ludwig.data.dataframe.dask import DaskEngine
from ludwig.data.preprocessing import balance_data
from ludwig.utils.data_utils import read_parquet
Expand Down Expand Up @@ -148,14 +148,18 @@ def run_test_parquet(
num_cpus=2,
num_gpus=None,
df_engine=None,
preprocessing=None,
):
preprocessing = preprocessing or {}
with ray_start(num_cpus=num_cpus, num_gpus=num_gpus):
config = {
"input_features": input_features,
"output_features": output_features,
"combiner": {"type": "concat", "output_size": 14},
TRAINER: {"epochs": 2, "batch_size": 8},
}
if preprocessing:
config[PREPROCESSING] = preprocessing

backend_config = {**RAY_BACKEND_CONFIG}
if df_engine:
Expand Down Expand Up @@ -261,6 +265,21 @@ def test_ray_split():
)


@pytest.mark.distributed
def test_ray_stratify():
input_features = [
number_feature(normalization="zscore"),
set_feature(),
binary_feature(),
]
output_features = [binary_feature()]
run_test_parquet(
input_features,
output_features,
preprocessing={"stratify": output_features[0][NAME]},
)


@pytest.mark.distributed
def test_ray_timeseries():
input_features = [timeseries_feature()]
Expand Down

0 comments on commit 2d8c7b8

Please sign in to comment.