diff --git a/doubleml/tests/test_set_sample_splitting.py b/doubleml/tests/test_set_sample_splitting.py index 0995d831..fa0a4394 100644 --- a/doubleml/tests/test_set_sample_splitting.py +++ b/doubleml/tests/test_set_sample_splitting.py @@ -236,3 +236,43 @@ def test_doubleml_set_sample_splitting_invalid_sets(): msg = r"Invalid sample split. Test indices must be in \[0, n_obs\)." with pytest.raises(ValueError, match=msg): dml_plr.set_sample_splitting(smpls) + + +@pytest.mark.ci +def test_doubleml_set_sample_splitting_shuffled_indices(): + """Test that externally provided partitions work with shuffled (unsorted) indices.""" + # Create valid 2-fold partition with sorted indices + sorted_smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]), ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])] + + # Create the same partition but with shuffled indices + shuffled_smpls = [([4, 1, 0, 3, 2], [8, 5, 9, 6, 7]), ([7, 9, 5, 6, 8], [2, 4, 0, 1, 3])] + + # Both should work and produce equivalent results + dml_plr_sorted = DoubleMLPLR(dml_data, ml_l, ml_m, n_folds=2, n_rep=2, draw_sample_splitting=False) + dml_plr_shuffled = DoubleMLPLR(dml_data, ml_l, ml_m, n_folds=2, n_rep=2, draw_sample_splitting=False) + + dml_plr_sorted.set_sample_splitting(sorted_smpls) + dml_plr_shuffled.set_sample_splitting(shuffled_smpls) + + # Both should have same fold structure + assert dml_plr_sorted.n_folds == 2 + assert dml_plr_shuffled.n_folds == 2 + assert dml_plr_sorted.n_rep == 1 + assert dml_plr_shuffled.n_rep == 1 + + # Fit both models + dml_plr_sorted.fit(store_predictions=True) + dml_plr_shuffled.fit(store_predictions=True) + + # Check if coefficient estimates are identical + np.testing.assert_allclose(dml_plr_sorted.coef, dml_plr_shuffled.coef, rtol=1e-10) + np.testing.assert_allclose(dml_plr_sorted.se, dml_plr_shuffled.se, rtol=1e-10) + + sorted_preds_l = dml_plr_sorted.predictions["ml_l"][:, 0, 0] # First rep, first treatment + sorted_preds_m = dml_plr_sorted.predictions["ml_m"][:, 0, 0] + shuffled_preds_l = dml_plr_shuffled.predictions["ml_l"][:, 0, 0] + shuffled_preds_m = dml_plr_shuffled.predictions["ml_m"][:, 0, 0] + + # Since predictions are stored by observation index, they should be identical + np.testing.assert_allclose(sorted_preds_l, shuffled_preds_l, rtol=1e-10) + np.testing.assert_allclose(sorted_preds_m, shuffled_preds_m, rtol=1e-10)