Skip to content

Commit e33ce25

Browse files
committed
#76 add input validation
1 parent b1310a6 commit e33ce25

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

causalpy/pymc_experiments.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,15 +232,25 @@ def __init__(
232232
self.y, self.X = np.asarray(y), np.asarray(X)
233233
self.outcome_variable_name = y.design_info.column_names[0]
234234

235+
# Input validation ----------------------------------------------------
236+
# Check that `treated` appears in the module formula
235237
assert (
236238
"treated" in formula
237239
), "A predictor column called `treated` should be in the provided dataframe"
240+
# Check that we have `treated` in the incoming dataframe
241+
assert (
242+
"treated" in self.data.columns
243+
), "Require a boolean column labelling observations which are `treated`"
244+
# Check for `unit` in the incoming dataframe. *This is only used for plotting purposes*
245+
assert (
246+
"unit" in self.data.columns
247+
), "Require a `unit` column to label unique units. This is used for plotting purposes"
248+
# Check that `group_variable_name` has TWO levels, representing the treated/untreated. But it does not matter what the actual names of the levels are.
249+
assert (
250+
len(pd.Categorical(self.data[self.group_variable_name]).categories) is 2
251+
), f"There must be 2 levels of the grouping variable {self.group_variable_name}. I.e. the treated and untreated."
238252

239-
# TODO: check that data in column self.group_variable_name has TWO levels
240-
241-
# TODO: check we have `unit` as a predictor column which is an vector of labels of unique units
242-
243-
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
253+
# TODO: `treated` is a deterministic function of group and time, so this could be a function rather than supplied data
244254

245255
# DEVIATION FROM SKL EXPERIMENT CODE =============================
246256
# fit the model to the observed (pre-intervention) data

0 commit comments

Comments
 (0)