@@ -232,15 +232,25 @@ def __init__(
232232 self .y , self .X = np .asarray (y ), np .asarray (X )
233233 self .outcome_variable_name = y .design_info .column_names [0 ]
234234
235+ # Input validation ----------------------------------------------------
236+ # Check that `treated` appears in the model formula
235237 assert (
236238 "treated" in formula
237239 ), "A predictor column called `treated` should be in the provided dataframe"
240+ # Check that we have `treated` in the incoming dataframe
241+ assert (
242+ "treated" in self .data .columns
243+ ), "Require a boolean column labelling observations which are `treated`"
244+ # Check for `unit` in the incoming dataframe. *This is only used for plotting purposes*
245+ assert (
246+ "unit" in self .data .columns
247+ ), "Require a `unit` column to label unique units. This is used for plotting purposes"
248+ # Check that `group_variable_name` has TWO levels, representing the treated/untreated. But it does not matter what the actual names of the levels are.
249+ assert (
250+ len (pd .Categorical (self .data [self .group_variable_name ]).categories ) is 2
251+ ), f"There must be 2 levels of the grouping variable { self .group_variable_name } . I.e. the treated and untreated."
238252
239- # TODO: check that data in column self.group_variable_name has TWO levels
240-
241- # TODO: check we have `unit` as a predictor column which is an vector of labels of unique units
242-
243- # TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
253+ # TODO: `treated` is a deterministic function of group and time, so this could be a function rather than supplied data
244254
245255 # DEVIATION FROM SKL EXPERIMENT CODE =============================
246256 # fit the model to the observed (pre-intervention) data
0 commit comments