diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index bf39caa..c139dd9 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -335,6 +335,10 @@ def evaluate( treatment_evaluation = {} tau_hat = self.predict(X=X, is_oos=is_oos, oos_method=oos_method) for treatment_variant in range(1, self.n_variants): + is_treatment = w == treatment_variant + is_control = w == 0 + mask = is_treatment | is_control + propensity_estimates = w_hat[:, treatment_variant] / ( w_hat[:, 0] + w_hat[:, treatment_variant] ) @@ -344,11 +348,11 @@ def evaluate( else tau_hat[:, treatment_variant - 1, 0] ) treatment_evaluation[f"r_loss_{treatment_variant}_vs_0"] = r_loss( - cate_estimates=cate_estimates, - outcome_estimates=y_hat, - propensity_scores=propensity_estimates, - outcomes=y, - treatments=w, + cate_estimates=cate_estimates[mask], + outcome_estimates=y_hat[mask], + propensity_scores=propensity_estimates[mask], + outcomes=y[mask], + treatments=w[mask] == treatment_variant, ) return propensity_evaluation | outcome_evaluation | treatment_evaluation