diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala index 6e63c7c2e4..b8eb604630 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala @@ -20,7 +20,7 @@ import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap} import org.apache.spark.ml.param.shared.{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, LongType, StructType} +import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, IntegerType, LongType, StructType} import org.apache.commons.math3.stat.inference.TestUtils import scala.concurrent.Future @@ -77,21 +77,9 @@ class DoubleMLEstimator(override val uid: String) logFit({ require(getMaxIter > 0, "maxIter should be larger than 0!") val treatmentColType = dataset.schema(getTreatmentCol).dataType - require(treatmentColType == DoubleType || treatmentColType == LongType - || treatmentColType == IntegerType || treatmentColType == BooleanType, - s"TreatmentCol must be of type DoubleType, LongType, IntegerType or BooleanType but got $treatmentColType") - - if (treatmentColType != IntegerType && treatmentColType != BooleanType - && getTreatmentType == TreatmentTypes.Binary) { - throw new Exception("TreatmentModel was set to use classifier " + - "but treatment column in dataset isn't integer or boolean type.") - } - - if (treatmentColType != DoubleType && treatmentColType != LongType - && getTreatmentType == TreatmentTypes.Continuous) { - throw new Exception("TreatmentModel was set to use regression " + - "but treatment column in dataset isn't continuous data type.") - } + validateColTypeWithModel(treatmentColType, getTreatmentCol, getTreatmentModel) + val outcomeColType = dataset.schema(getOutcomeCol).dataType + validateColTypeWithModel(outcomeColType, getOutcomeCol, getOutcomeModel) if (get(weightCol).isDefined) { getTreatmentModel match { @@ -256,7 +244,7 @@ class DoubleMLEstimator(override val uid: String) .setFeaturesCol(treatmentResidualVecCol) .setFamily("gaussian") .setLink("identity") - .setFitIntercept(false) + .setFitIntercept(true) val coefficients = Array(residualsDF1, residualsDF2).map(regressor.fit).map(_.coefficients(0)) val ate = coefficients.sum / coefficients.length @@ -273,6 +261,23 @@ class DoubleMLEstimator(override val uid: String) override def transformSchema(schema: StructType): StructType = { DoubleMLEstimator.validateTransformSchema(schema) } + + protected def validateColTypeWithModel(colType: DataType, colName: String, model: Estimator[_]): Unit = { + val modelType = getDoubleMLModelType(model) + colType match { + case IntegerType | BooleanType => + if (modelType == DoubleMLModelTypes.Continuous) + throw new Exception(s"column $colName in dataset is integer or boolean data type " + + s"but you set to use a regression model for it.") + case DoubleType | LongType => + if (modelType == DoubleMLModelTypes.Binary) + throw new Exception(s"column $colName in dataset is double or long data type" + + "but you set to use a classification model for it.") + case _ => + throw new Exception(s"column $colName must be of type DoubleType, LongType, " + + s"IntegerType or BooleanType but got $colType") + } + } } object DoubleMLEstimator extends ComplexParamsReadable[DoubleMLEstimator] { diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLParams.scala index d40cfaeabd..b8016f568a 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLParams.scala @@ -85,22 +85,24 @@ trait DoubleMLParams extends Params */ def setSampleSplitRatio(value: Array[Double]): this.type = set(sampleSplitRatio, value) - private[causal] object TreatmentTypes extends Enumeration { + private[causal] object DoubleMLModelTypes extends Enumeration { type TreatmentType = Value val Binary, Continuous = Value } - private[causal] def getTreatmentType: TreatmentTypes.Value = { - val treatmentType = - getTreatmentModel match { - case _: ProbabilisticClassifier[_, _, _] => - TreatmentTypes.Binary - case _: Regressor[_, _, _] => - TreatmentTypes.Continuous - } - treatmentType + private[causal] def getDoubleMLModelType(model: Any): DoubleMLModelTypes.Value = { + model match { + case _: ProbabilisticClassifier[_, _, _] => + DoubleMLModelTypes.Binary + case _: Regressor[_, _, _] => + DoubleMLModelTypes.Continuous + case _ => + throw new IllegalArgumentException(s"Invalid model type: ${model.getClass.getName}") + } } + + val confidenceLevel = new DoubleParam( this, "confidenceLevel", diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala index b4f6825f04..218573aeaa 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala @@ -92,7 +92,7 @@ class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] { assert(ateLow < ateHigh && ateLow > -130 && ateHigh < 130) } - test("Invalid treatment model will throw exception.") { + test("Mismatch treatment model and treatment column will throw exception.") { assertThrows[Exception] { val ldml = new DoubleMLEstimator() .setTreatmentModel(new LinearRegression()) @@ -101,7 +101,21 @@ class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] { .setOutcomeCol("col2") .setMaxIter(20) - val ldmlModel = ldml.fit(mockDataset) + ldml.fit(mockDataset) + } + } + + test("Mismatch outcome model and outcome column will throw exception.") { + assertThrows[Exception] { + val ldml = new DoubleMLEstimator() + .setTreatmentModel(new LogisticRegression()) + .setTreatmentCol(mockLabelColumn) + .setOutcomeModel(new LinearRegression()) + .setOutcomeCol("col1") + .setMaxIter(5) + + val dmlModel = ldml.fit(mockDataset) + dmlModel.getAvgTreatmentEffect } }