From fa5c9e29465c388b914761f0d89c3a431eaac7eb Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Tue, 17 Jan 2023 10:55:10 -0500 Subject: [PATCH] Refactor DynamicDML to remove incompatible method signatures --- econml/panel/dml/_dml.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/econml/panel/dml/_dml.py b/econml/panel/dml/_dml.py index 32207565b..2cfd1f1bb 100644 --- a/econml/panel/dml/_dml.py +++ b/econml/panel/dml/_dml.py @@ -547,32 +547,29 @@ def _gen_model_t(self): def _gen_model_final(self): return StatsModelsLinearRegression(fit_intercept=False) - def _gen_ortho_learner_model_nuisance(self, n_periods): + def _gen_ortho_learner_model_nuisance(self): return _DynamicModelNuisance( model_t=self._gen_model_t(), model_y=self._gen_model_y(), - n_periods=n_periods) + n_periods=self._n_periods) - def _gen_ortho_learner_model_final(self, n_periods): + def _gen_ortho_learner_model_final(self): wrapped_final_model = _DynamicFinalWrapper( StatsModelsLinearRegression(fit_intercept=False), fit_cate_intercept=self.fit_cate_intercept, featurizer=self.featurizer, use_weight_trick=False) - return _LinearDynamicModelFinal(wrapped_final_model, n_periods=n_periods) + return _LinearDynamicModelFinal(wrapped_final_model, n_periods=self._n_periods) def _prefit(self, Y, T, *args, groups=None, only_final=False, **kwargs): + # we need to set the number of periods before calling super()._prefit, since that will generate the + # final and nuisance models, which need to have self._n_periods set u_periods = np.unique(np.unique(groups, return_counts=True)[1]) if len(u_periods) > 1: raise AttributeError( "Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data") self._n_periods = u_periods[0] - # generate an instance of the final model - self._ortho_learner_model_final = self._gen_ortho_learner_model_final(self._n_periods) - if not only_final: - # generate an instance of the nuisance model - self._ortho_learner_model_nuisance = self._gen_ortho_learner_model_nuisance(self._n_periods) - TreatmentExpansionMixin._prefit(self, Y, T, *args, **kwargs) + super()._prefit(Y, T, *args, **kwargs) def _postfit(self, Y, T, *args, **kwargs): super()._postfit(Y, T, *args, **kwargs)