@@ -236,3 +236,43 @@ def test_doubleml_set_sample_splitting_invalid_sets():
236236 msg = r"Invalid sample split. Test indices must be in \[0, n_obs\)."
237237 with pytest .raises (ValueError , match = msg ):
238238 dml_plr .set_sample_splitting (smpls )
239+
240+
241+ @pytest .mark .ci
242+ def test_doubleml_set_sample_splitting_shuffled_indices ():
243+ """Test that externally provided partitions work with shuffled (unsorted) indices."""
244+ # Create valid 2-fold partition with sorted indices
245+ sorted_smpls = [([0 , 1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 , 9 ]), ([5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 ])]
246+
247+ # Create the same partition but with shuffled indices
248+ shuffled_smpls = [([4 , 1 , 0 , 3 , 2 ], [8 , 5 , 9 , 6 , 7 ]), ([7 , 9 , 5 , 6 , 8 ], [2 , 4 , 0 , 1 , 3 ])]
249+
250+ # Both should work and produce equivalent results
251+ dml_plr_sorted = DoubleMLPLR (dml_data , ml_l , ml_m , n_folds = 2 , n_rep = 2 , draw_sample_splitting = False )
252+ dml_plr_shuffled = DoubleMLPLR (dml_data , ml_l , ml_m , n_folds = 2 , n_rep = 2 , draw_sample_splitting = False )
253+
254+ dml_plr_sorted .set_sample_splitting (sorted_smpls )
255+ dml_plr_shuffled .set_sample_splitting (shuffled_smpls )
256+
257+ # Both should have same fold structure
258+ assert dml_plr_sorted .n_folds == 2
259+ assert dml_plr_shuffled .n_folds == 2
260+ assert dml_plr_sorted .n_rep == 1
261+ assert dml_plr_shuffled .n_rep == 1
262+
263+ # Fit both models
264+ dml_plr_sorted .fit (store_predictions = True )
265+ dml_plr_shuffled .fit (store_predictions = True )
266+
267+ # Check if coefficient estimates are identical
268+ np .testing .assert_allclose (dml_plr_sorted .coef , dml_plr_shuffled .coef , rtol = 1e-10 )
269+ np .testing .assert_allclose (dml_plr_sorted .se , dml_plr_shuffled .se , rtol = 1e-10 )
270+
271+ sorted_preds_l = dml_plr_sorted .predictions ["ml_l" ][:, 0 , 0 ] # First rep, first treatment
272+ sorted_preds_m = dml_plr_sorted .predictions ["ml_m" ][:, 0 , 0 ]
273+ shuffled_preds_l = dml_plr_shuffled .predictions ["ml_l" ][:, 0 , 0 ]
274+ shuffled_preds_m = dml_plr_shuffled .predictions ["ml_m" ][:, 0 , 0 ]
275+
276+ # Since predictions are stored by observation index, they should be identical
277+ np .testing .assert_allclose (sorted_preds_l , shuffled_preds_l , rtol = 1e-10 )
278+ np .testing .assert_allclose (sorted_preds_m , shuffled_preds_m , rtol = 1e-10 )
0 commit comments