
Commit

more scalastyle fixes
memoryz committed Oct 26, 2023
1 parent 5787e9e commit faa1900
Showing 5 changed files with 9 additions and 8 deletions.
@@ -25,6 +25,7 @@ class SyntheticControlEstimator(override val uid: String)

def this() = this(Identifiable.randomUID("syncon"))

+// scalastyle:off method.length
override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
val df = dataset
.withColumn(getTreatmentCol, treatment.cast(BooleanType))
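The only change in this hunk is the added // scalastyle:off method.length directive, which tells scalastyle to stop enforcing its method-length limit for the fit method that follows. As a rough illustration of the pattern (the object and method below are hypothetical, not SynapseML code), the directive is normally paired with a matching // scalastyle:on:

// Hypothetical illustration of the scalastyle suppression pattern; not part of this commit.
object MethodLengthExample {
  // scalastyle:off method.length
  def deliberatelyLongMethod(): Int = {
    // ... imagine enough statements here to exceed the configured method-length limit ...
    val squares = (1 to 100).map(i => i * i)
    squares.sum
  }
  // scalastyle:on method.length
}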
@@ -24,6 +24,7 @@ class SyntheticDiffInDiffEstimator(override val uid: String)

def this() = this(Identifiable.randomUID("syndid"))

+// scalastyle:off method.length
override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
val df = dataset
.withColumn(getTreatmentCol, treatment.cast(BooleanType))
@@ -32,9 +33,7 @@ class SyntheticDiffInDiffEstimator(override val uid: String)
val controlDf = df.filter(not(treatment)).cache
val preDf = df.filter(not(postTreatment)).cache
val timeIdx = createIndex(preDf, getTimeCol, TimeIdxCol).cache
-timeIdx.show(100, false)
val unitIdx = createIndex(controlDf, getUnitCol, UnitIdxCol).cache
-unitIdx.show(100, false)
val size = (unitIdx.count, timeIdx.count)

// indexing
@@ -49,7 +48,6 @@ class SyntheticDiffInDiffEstimator(override val uid: String)
.localCheckpoint(true)

// fit time weights

val (timeWeights, timeIntercept, lossHistoryTimeWeights) = fitTimeWeights(
handleMissingOutcomes(indexedControlDf, timeIdx.count.toInt), size
)
@@ -66,7 +64,7 @@ class SyntheticDiffInDiffEstimator(override val uid: String)
// join weights
val Row(t: Long, u: Long) = df.agg(
countDistinct(when(postTreatment, col(getTimeCol))),
-countDistinct(when(treatment, col(getUnitCol))),
+countDistinct(when(treatment, col(getUnitCol)))
).head

val indexedDf = df.join(timeIdx, df(getTimeCol) === timeIdx(getTimeCol), "left_outer")
@@ -28,7 +28,8 @@ trait SyntheticEstimator extends SynapseMLLogging {
private[causal] val weightsCol = "weights"
private[causal] val epsilon = 1E-10

-private def solveCLS(A: DMatrix, b: DVector, lambda: Double, fitIntercept: Boolean, size: (Long, Long)): (DVector, Double, Seq[Double]) = {
+private def solveCLS(A: DMatrix, b: DVector, lambda: Double, fitIntercept: Boolean, size: (Long, Long))
+  : (DVector, Double, Seq[Double]) = {
if (size._1 * size._2 <= getLocalSolverThreshold) {
// If matrix size is less than LocalSolverThreshold (defaults to 1M),
// collect the data on the driver node and solve it locally, where matrix-vector
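The comment in this hunk describes a size-based dispatch: when the matrix has at most getLocalSolverThreshold entries (1M by default), the data is collected to the driver and solved locally; otherwise it stays distributed. A minimal sketch of that dispatch, using hypothetical solveLocally/solveDistributed stand-ins rather than the actual DMatrix/DVector machinery:

object SolverDispatchSketch {
  // Sketch only: the by-name parameters stand in for a local and a distributed solver.
  def solveWithThreshold[A](size: (Long, Long), localSolverThreshold: Long)
                           (solveLocally: => A, solveDistributed: => A): A = {
    if (size._1 * size._2 <= localSolverThreshold) {
      // small problem: cheap to collect onto the driver and solve with local linear algebra
      solveLocally
    } else {
      // large problem: keep the matrix distributed and solve on the cluster
      solveDistributed
    }
  }
}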
@@ -6,7 +6,9 @@ package com.microsoft.azure.synapse.ml.causal.linalg
import breeze.linalg.{norm, DenseVector => BDV, max => bmax, sum => bsum}
import breeze.numerics.{abs => babs, exp => bexp}
import breeze.stats.{mean => bmean}
-import org.apache.spark.sql.functions.{coalesce, col, lit, abs => sabs, exp => sexp, max => smax, mean => smean, sum => ssum}
+import org.apache.spark.sql.functions.{
+  coalesce, col, lit, abs => sabs, exp => sexp, max => smax, mean => smean, sum => ssum
+}
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}

case class VectorEntry(i: Long, value: Double)
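The reformatted import highlights the aliasing pattern used in this file: Breeze and Spark SQL both define functions named abs, exp, max, mean and sum, so each side is renamed on import (babs vs. sabs, and so on) to keep both usable without ambiguity. A small self-contained example of the same idea (the object and data below are made up for illustration):

import breeze.linalg.{DenseVector => BDV, sum => bsum}
import breeze.numerics.{abs => babs}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{abs => sabs, sum => ssum}

object ImportAliasDemo {
  def main(args: Array[String]): Unit = {
    // Breeze: element-wise abs on a local dense vector
    val local = babs(BDV(-1.0, 2.0, -3.0))
    println(bsum(local)) // 6.0

    // Spark SQL: the aliased column functions on a DataFrame
    val spark = SparkSession.builder.master("local[1]").appName("import-alias-demo").getOrCreate()
    import spark.implicits._
    val df = Seq(-1.0, 2.0, -3.0).toDF("x")
    df.select(ssum(sabs($"x"))).show() // sum(abs(x)) = 6.0
    spark.stop()
  }
}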
@@ -3,11 +3,11 @@

package com.microsoft.azure.synapse.ml.causal.opt

+// scalastyle:off non.ascii.character.disallowed
import breeze.optimize.DiffFunction
import com.microsoft.azure.synapse.ml.causal.CacheOps
import com.microsoft.azure.synapse.ml.causal.linalg.{MatrixOps, VectorOps}

-// scalastyle:off non.ascii.character.disallowed
/**
* Solver for the following constrained least square problem:
* minimize ||Ax-b||^2^ + λ||x||^2^, s.t. 1^T^x = 1, 0 ≤ x ≤ 1
@@ -16,7 +16,6 @@ import com.microsoft.azure.synapse.ml.causal.linalg.{MatrixOps, VectorOps}
* @param numIterNoChange max number of iteration without change in loss function allowed before termination.
* @param tol tolerance for loss function
*/
-// scalastyle:on
private[causal] class ConstrainedLeastSquare[TMat, TVec](step: Double,
maxIter: Int,
numIterNoChange: Option[Int] = None,
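For reference, the optimization problem stated in the scaladoc above (the source of the non-ASCII λ and ≤ characters that the relocated scalastyle directive exempts) reads, in standard notation:

\[
  \min_x \; \lVert A x - b \rVert^2 + \lambda \lVert x \rVert^2
  \quad \text{subject to} \quad \mathbf{1}^{\top} x = 1, \quad 0 \le x \le 1
\]

The constraint set (nonnegative entries summing to one) is the probability simplex, which is what lets the solution be read as a set of weights.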
