Python style fix
memoryz committed Oct 27, 2023
1 parent a7b6f20 commit 67e26cf
Showing 2 changed files with 99 additions and 32 deletions.
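The changes in both files are pure formatting: long single-line statements are split into multi-line calls with one argument per line and a trailing comma, the layout an autoformatter such as Black produces (an assumption; the commit does not name the tool). A minimal sketch of the pattern, with illustrative names only:

    # Illustrative helper so the sketch is self-contained; not from this commit.
    def make_estimator(**kwargs):
        return kwargs

    # Before: one long call that exceeds the line-length limit.
    model = make_estimator(treatment_col="california", outcome_col="cigsale", max_iter=5000, tol=1e-4)

    # After: the same call wrapped one argument per line, with a trailing comma.
    model = make_estimator(
        treatment_col="california",
        outcome_col="cigsale",
        max_iter=5000,
        tol=1e-4,
    )

The behavior of the code is unchanged; only line layout, quoting, and numeric-literal casing (1E-4 to 1e-4) differ.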
16 changes: 12 additions & 4 deletions core/src/main/python/synapse/ml/causal/DiffInDiffModel.py
@@ -33,7 +33,15 @@ def __init__(self, java_obj=None) -> None:
self.standardError = self.summary.standardError()
self.timeIntercept = DiffInDiffModel._unwrapOption(self.summary.timeIntercept())
self.unitIntercept = DiffInDiffModel._unwrapOption(self.summary.unitIntercept())
self.timeWeights = DiffInDiffModel._mapOption(java_obj.getTimeWeights(), lambda x: DataFrame(x, sql_ctx))
self.unitWeights = DiffInDiffModel._mapOption(java_obj.getUnitWeights(), lambda x: DataFrame(x, sql_ctx))
self.lossHistoryTimeWeights = DiffInDiffModel._unwrapOption(self.summary.getLossHistoryTimeWeightsJava())
self.lossHistoryUnitWeights = DiffInDiffModel._unwrapOption(self.summary.getLossHistoryUnitWeightsJava())
self.timeWeights = DiffInDiffModel._mapOption(
java_obj.getTimeWeights(), lambda x: DataFrame(x, sql_ctx)
)
self.unitWeights = DiffInDiffModel._mapOption(
java_obj.getUnitWeights(), lambda x: DataFrame(x, sql_ctx)
)
self.lossHistoryTimeWeights = DiffInDiffModel._unwrapOption(
self.summary.getLossHistoryTimeWeightsJava()
)
self.lossHistoryUnitWeights = DiffInDiffModel._unwrapOption(
self.summary.getLossHistoryUnitWeightsJava()
)
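The reflowed constructor above relies on two small helpers on DiffInDiffModel, `_unwrapOption` and `_mapOption`, which adapt Scala `Option` values returned by the wrapped Java object. Their implementations are not part of this diff; the following is a hypothetical sketch of the behavior they are assumed to provide, using only the standard `isDefined()`/`get()` methods a Scala Option exposes over Py4J:

    # Hypothetical sketch, not taken from this commit: adapt a Scala Option
    # reaching Python through Py4J into a plain value or None.
    def _unwrap_option(java_option):
        # Return the wrapped value when the Option is defined, otherwise None.
        return java_option.get() if java_option.isDefined() else None

    def _map_option(java_option, func):
        # Apply func to the wrapped value when the Option is defined, otherwise None.
        return func(java_option.get()) if java_option.isDefined() else None

In the constructor, `_mapOption` is what turns an optional Java weights table into an optional PySpark `DataFrame`, and `_unwrapOption` does the same for plain values such as the loss histories and intercepts.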
115 changes: 87 additions & 28 deletions (second changed file: a Jupyter notebook; its filename is not shown in this view)
@@ -38,7 +38,11 @@
"outputs": [],
"source": [
"from pyspark.sql.types import *\n",
"from synapse.ml.causal import DiffInDiffEstimator, SyntheticControlEstimator, SyntheticDiffInDiffEstimator\n",
"from synapse.ml.causal import (\n",
" DiffInDiffEstimator,\n",
" SyntheticControlEstimator,\n",
" SyntheticDiffInDiffEstimator,\n",
")\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import style\n",
"import pandas as pd\n",
@@ -72,8 +76,12 @@
},
"outputs": [],
"source": [
"df = spark.read.option(\"header\", True).option(\"inferSchema\", True).csv(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/smoking.csv\") \\\n",
" .select(\"state\", \"year\", \"cigsale\", \"california\", \"after_treatment\")\n",
"df = (\n",
" spark.read.option(\"header\", True)\n",
" .option(\"inferSchema\", True)\n",
" .csv(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/smoking.csv\")\n",
" .select(\"state\", \"year\", \"cigsale\", \"california\", \"after_treatment\")\n",
")\n",
"display(df)"
]
},
@@ -100,7 +108,10 @@
},
"outputs": [],
"source": [
"estimator1 = DiffInDiffEstimator(treatmentCol=\"california\", postTreatmentCol=\"after_treatment\", outcomeCol = \"cigsale\")\n",
"estimator1 = DiffInDiffEstimator(\n",
" treatmentCol=\"california\", postTreatmentCol=\"after_treatment\", outcomeCol=\"cigsale\"\n",
")\n",
"\n",
"model1 = estimator1.fit(df)\n",
"\n",
"print(\"[Diff in Diff] treatment effect: {}\".format(model1.treatmentEffect))\n",
@@ -138,8 +149,16 @@
"outputs": [],
"source": [
"estimator2 = SyntheticControlEstimator(\n",
" timeCol = \"year\", unitCol = \"state\", treatmentCol = \"california\", postTreatmentCol = \"after_treatment\", outcomeCol = \"cigsale\", \n",
" maxIter = 5000, numIterNoChange = 50, tol = 1E-4, stepSize = 1.0)\n",
" timeCol=\"year\",\n",
" unitCol=\"state\",\n",
" treatmentCol=\"california\",\n",
" postTreatmentCol=\"after_treatment\",\n",
" outcomeCol=\"cigsale\",\n",
" maxIter=5000,\n",
" numIterNoChange=50,\n",
" tol=1e-4,\n",
" stepSize=1.0,\n",
")\n",
"\n",
"model2 = estimator2.fit(df)\n",
"\n",
@@ -180,9 +199,9 @@
"lossHistory = pd.Series(np.array(model2.lossHistoryUnitWeights))\n",
"\n",
"plt.plot(lossHistory[2000:])\n",
"plt.title('loss history - unit weights')\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel('Loss')\n",
"plt.title(\"loss history - unit weights\")\n",
"plt.xlabel(\"Iteration\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()\n",
"\n",
"print(\"Mimimal loss: {}\".format(lossHistory.min()))"
@@ -213,14 +232,25 @@
"source": [
"sc_weights = model2.unitWeights.toPandas().set_index(\"state\")\n",
"pdf = df.toPandas()\n",
"sc = pdf.query(\"~california\").pivot(index=\"year\", columns=\"state\", values=\"cigsale\").dot(sc_weights)\n",
"\n",
"sc = (\n",
" pdf.query(\"~california\")\n",
" .pivot(index=\"year\", columns=\"state\", values=\"cigsale\")\n",
" .dot(sc_weights)\n",
")\n",
"plt.plot(sc, label=\"Synthetic Control\")\n",
"plt.plot(sc.index, pdf.query(\"california\")[\"cigsale\"], label=\"California\", color=\"C1\")\n",
"\n",
"plt.title(\"Synthetic Control Estimation\")\n",
"plt.ylabel(\"Cigarette Sales\")\n",
"plt.vlines(x=1988, ymin=40, ymax=140, linestyle=\":\", lw=2, label=\"Proposition 99\", color=\"black\")\n",
"plt.vlines(\n",
" x=1988,\n",
" ymin=40,\n",
" ymax=140,\n",
" linestyle=\":\",\n",
" lw=2,\n",
" label=\"Proposition 99\",\n",
" color=\"black\",\n",
")\n",
"plt.legend()"
]
},
@@ -248,8 +278,16 @@
"outputs": [],
"source": [
"estimator3 = SyntheticDiffInDiffEstimator(\n",
" timeCol = \"year\", unitCol = \"state\", treatmentCol = \"california\", postTreatmentCol = \"after_treatment\", outcomeCol = \"cigsale\", \n",
" maxIter = 5000, numIterNoChange = 50, tol = 1E-4, stepSize = 1.0)\n",
" timeCol=\"year\",\n",
" unitCol=\"state\",\n",
" treatmentCol=\"california\",\n",
" postTreatmentCol=\"after_treatment\",\n",
" outcomeCol=\"cigsale\",\n",
" maxIter=5000,\n",
" numIterNoChange=50,\n",
" tol=1e-4,\n",
" stepSize=1.0,\n",
")\n",
"\n",
"model3 = estimator3.fit(df)\n",
"\n",
@@ -290,9 +328,9 @@
"lossHistory = pd.Series(np.array(model3.lossHistoryUnitWeights))\n",
"\n",
"plt.plot(lossHistory[1000:])\n",
"plt.title('loss history - unit weights')\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel('Loss')\n",
"plt.title(\"loss history - unit weights\")\n",
"plt.xlabel(\"Iteration\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()\n",
"\n",
"print(\"Mimimal loss: {}\".format(lossHistory.min()))"
@@ -317,9 +355,9 @@
"lossHistory = pd.Series(np.array(model3.lossHistoryTimeWeights))\n",
"\n",
"plt.plot(lossHistory[1000:])\n",
"plt.title('loss history - time weights')\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel('Loss')\n",
"plt.title(\"loss history - time weights\")\n",
"plt.xlabel(\"Iteration\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()\n",
"\n",
"print(\"Mimimal loss: {}\".format(lossHistory.min()))"
@@ -355,8 +393,12 @@
"time_intercept = model3.timeIntercept\n",
"\n",
"pdf = df.toPandas()\n",
"pivot_df_control = pdf.query(\"~california\").pivot(index='year', columns='state', values='cigsale')\n",
"pivot_df_treat = pdf.query(\"california\").pivot(index='year', columns='state', values='cigsale')\n",
"pivot_df_control = pdf.query(\"~california\").pivot(\n",
" index=\"year\", columns=\"state\", values=\"cigsale\"\n",
")\n",
"pivot_df_treat = pdf.query(\"california\").pivot(\n",
" index=\"year\", columns=\"state\", values=\"cigsale\"\n",
")\n",
"sc_did = pivot_df_control.values @ unit_weights.values\n",
"treated_mean = pivot_df_treat.mean(axis=1)"
]
@@ -420,19 +462,36 @@
}
],
"source": [
"fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15,8), sharex=True, gridspec_kw={'height_ratios': [3, 1]})\n",
"fig, (ax1, ax2) = plt.subplots(\n",
" 2, 1, figsize=(15, 8), sharex=True, gridspec_kw={\"height_ratios\": [3, 1]}\n",
")\n",
"fig.suptitle(\"Synthetic Diff in Diff Estimation\")\n",
"\n",
"ax1.plot(pivot_df_control.mean(axis=1), lw=3, color=\"C1\", ls=\"dashed\", label=\"Control Avg.\")\n",
"ax1.plot(\n",
" pivot_df_control.mean(axis=1), lw=3, color=\"C1\", ls=\"dashed\", label=\"Control Avg.\"\n",
")\n",
"ax1.plot(treated_mean, lw=3, color=\"C0\", label=\"California\")\n",
"ax1.plot(pivot_df_control.index, sc_did, label=\"Synthetic Control (SDID)\", color=\"C1\", alpha=.8)\n",
"ax1.plot(\n",
" pivot_df_control.index,\n",
" sc_did,\n",
" label=\"Synthetic Control (SDID)\",\n",
" color=\"C1\",\n",
" alpha=0.8,\n",
")\n",
"ax1.set_ylabel(\"Cigarette Sales\")\n",
"ax1.vlines(1989, treated_mean.min(), treated_mean.max(), color=\"black\", ls=\"dotted\", label=\"Prop. 99\")\n",
"ax1.vlines(\n",
" 1989,\n",
" treated_mean.min(),\n",
" treated_mean.max(),\n",
" color=\"black\",\n",
" ls=\"dotted\",\n",
" label=\"Prop. 99\",\n",
")\n",
"ax1.legend()\n",
"\n",
"ax2.bar(time_weights.index, time_weights['value'], color='skyblue')\n",
"ax2.bar(time_weights.index, time_weights[\"value\"], color=\"skyblue\")\n",
"ax2.set_ylabel(\"Time Weights\")\n",
"ax2.set_xlabel(\"Time\");\n",
"ax2.set_xlabel(\"Time\")\n",
"ax2.vlines(1989, 0, 1, color=\"black\", ls=\"dotted\")"
]
}
