Skip to content

Commit 9eda55d

Browse files
authored
Merge pull request #356 from DoubleML/s-add-smpls-shuffle-test
Add test for shuffled external samples
2 parents 70300a0 + 9488306 commit 9eda55d

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

doubleml/tests/test_set_sample_splitting.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,43 @@ def test_doubleml_set_sample_splitting_invalid_sets():
236236
msg = r"Invalid sample split. Test indices must be in \[0, n_obs\)."
237237
with pytest.raises(ValueError, match=msg):
238238
dml_plr.set_sample_splitting(smpls)
239+
240+
241+
@pytest.mark.ci
242+
def test_doubleml_set_sample_splitting_shuffled_indices():
243+
"""Test that externally provided partitions work with shuffled (unsorted) indices."""
244+
# Create valid 2-fold partition with sorted indices
245+
sorted_smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]), ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])]
246+
247+
# Create the same partition but with shuffled indices
248+
shuffled_smpls = [([4, 1, 0, 3, 2], [8, 5, 9, 6, 7]), ([7, 9, 5, 6, 8], [2, 4, 0, 1, 3])]
249+
250+
# Both should work and produce equivalent results
251+
dml_plr_sorted = DoubleMLPLR(dml_data, ml_l, ml_m, n_folds=2, n_rep=2, draw_sample_splitting=False)
252+
dml_plr_shuffled = DoubleMLPLR(dml_data, ml_l, ml_m, n_folds=2, n_rep=2, draw_sample_splitting=False)
253+
254+
dml_plr_sorted.set_sample_splitting(sorted_smpls)
255+
dml_plr_shuffled.set_sample_splitting(shuffled_smpls)
256+
257+
# Both should have same fold structure
258+
assert dml_plr_sorted.n_folds == 2
259+
assert dml_plr_shuffled.n_folds == 2
260+
assert dml_plr_sorted.n_rep == 1
261+
assert dml_plr_shuffled.n_rep == 1
262+
263+
# Fit both models
264+
dml_plr_sorted.fit(store_predictions=True)
265+
dml_plr_shuffled.fit(store_predictions=True)
266+
267+
# Check if coefficient estimates are identical
268+
np.testing.assert_allclose(dml_plr_sorted.coef, dml_plr_shuffled.coef, rtol=1e-10)
269+
np.testing.assert_allclose(dml_plr_sorted.se, dml_plr_shuffled.se, rtol=1e-10)
270+
271+
sorted_preds_l = dml_plr_sorted.predictions["ml_l"][:, 0, 0] # First rep, first treatment
272+
sorted_preds_m = dml_plr_sorted.predictions["ml_m"][:, 0, 0]
273+
shuffled_preds_l = dml_plr_shuffled.predictions["ml_l"][:, 0, 0]
274+
shuffled_preds_m = dml_plr_shuffled.predictions["ml_m"][:, 0, 0]
275+
276+
# Since predictions are stored by observation index, they should be identical
277+
np.testing.assert_allclose(sorted_preds_l, shuffled_preds_l, rtol=1e-10)
278+
np.testing.assert_allclose(sorted_preds_m, shuffled_preds_m, rtol=1e-10)

0 commit comments

Comments
 (0)