diff --git a/doubleml/data/base_data.py b/doubleml/data/base_data.py
index 93543e8b..fecbded7 100644
--- a/doubleml/data/base_data.py
+++ b/doubleml/data/base_data.py
@@ -702,8 +702,12 @@ def _set_y_z(self):
         def _set_attr(col):
             if col is None:
                 return None
-            assert_all_finite(self.data.loc[:, col])
-            return self.data.loc[:, col]
+            if isinstance(col, list):
+                converted_data = self.data.loc[:, col].apply(pd.to_numeric, errors="raise")
+            else:
+                converted_data = pd.to_numeric(self.data.loc[:, col], errors="raise")
+            assert_all_finite(converted_data)
+            return converted_data
 
         self._y = _set_attr(self.y_col)
         self._z = _set_attr(self.z_cols)
@@ -740,7 +744,13 @@ def set_x_d(self, treatment_var):
             assert_all_finite(self.data.loc[:, self.d_cols], allow_nan=self.force_all_d_finite == "allow-nan")
         if self.force_all_x_finite:
             assert_all_finite(self.data.loc[:, xd_list], allow_nan=self.force_all_x_finite == "allow-nan")
-        self._d = self.data.loc[:, treatment_var]
+
+        treatment_data = self.data.loc[:, treatment_var]
+        # For panel data, preserve datetime type for treatment variables
+        if pd.api.types.is_datetime64_any_dtype(treatment_data):
+            self._d = treatment_data
+        else:
+            self._d = pd.to_numeric(treatment_data, errors="raise")
         self._X = self.data.loc[:, xd_list]
 
     def _get_optional_col_sets(self):
diff --git a/doubleml/data/tests/test_dml_data.py b/doubleml/data/tests/test_dml_data.py
index 9fb72934..542bbb61 100644
--- a/doubleml/data/tests/test_dml_data.py
+++ b/doubleml/data/tests/test_dml_data.py
@@ -1,3 +1,5 @@
+from decimal import Decimal
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -661,3 +663,32 @@ def test_property_setter_rollback_on_validation_failure():
         dml_data.z_cols = ["y"]
     # Object should remain unchanged
     assert dml_data.z_cols == original_z_cols
+
+
+@pytest.mark.ci
+def test_dml_data_decimal_to_float_conversion():
+    """Test that Decimal type columns are converted to float for y, d and z."""
+    n_obs = 100
+    data = {
+        "y": [Decimal(i * 0.1) for i in range(n_obs)],
+        "d": [Decimal(i * 0.05) for i in range(n_obs)],
+        "x": [Decimal(i) for i in range(n_obs)],
+        "z": [Decimal(i * 2) for i in range(n_obs)],
+    }
+    df = pd.DataFrame(data)
+
+    dml_data = DoubleMLData(df, y_col="y", d_cols="d", x_cols="x", z_cols="z")
+
+    assert dml_data.y.dtype == np.float64, f"Expected y to be float64, got {dml_data.y.dtype}"
+    assert dml_data.d.dtype == np.float64, f"Expected d to be float64, got {dml_data.d.dtype}"
+    assert dml_data.z.dtype == np.float64, f"Expected z to be float64, got {dml_data.z.dtype}"
+    # x is not converted to float, so it stays an object-dtype array holding Decimal values
+    assert dml_data.x.dtype == object
+
+    expected_y = np.array([float(Decimal(i * 0.1)) for i in range(n_obs)])
+    expected_d = np.array([float(Decimal(i * 0.05)) for i in range(n_obs)])
+    expected_z = np.array([float(Decimal(i * 2)) for i in range(n_obs)]).reshape(-1, 1)
+
+    np.testing.assert_array_almost_equal(dml_data.y, expected_y)
+    np.testing.assert_array_almost_equal(dml_data.d, expected_d)
+    np.testing.assert_array_almost_equal(dml_data.z, expected_z)