diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index fb26aa32c8..e383773b1b 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -140,16 +140,9 @@ def __init__( Default: None """ - self.model = model + self.model: Module = model - if isinstance(checkpoints, str): - self.checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*"))) - elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str): - self.checkpoints = AV.sort_files(checkpoints) - else: - self.checkpoints = list(checkpoints) # cast to avoid mypy error - if isinstance(self.checkpoints, List): - assert len(self.checkpoints) > 0, "No checkpoints saved!" + self.checkpoints = checkpoints # type: ignore self.checkpoints_load_func = checkpoints_load_func self.loss_fn = loss_fn @@ -181,6 +174,24 @@ def __init__( "percentage completion of the computation, nor any time estimates." ) + @property + def checkpoints(self) -> List[str]: + return self._checkpoints + + @checkpoints.setter + def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None: + if isinstance(checkpoints, str): + self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*"))) + elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str): + self._checkpoints = AV.sort_files(checkpoints) + else: + self._checkpoints = list(checkpoints) # cast to avoid mypy error + + if len(self._checkpoints) <= 0: + raise ValueError( + f"Invalid checkpoints provided for TracIn class: {checkpoints}!" + ) + @abstractmethod def self_influence( self, diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index ccc3bf061f..bf38ac0c09 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase): def __init__( self, model: Module, - final_fc_layer: Module, + final_fc_layer: Union[Module, str], train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, @@ -183,7 +183,7 @@ def __init__( self.vectorize = vectorize # TODO: restore prior state - self.final_fc_layer = final_fc_layer + self.final_fc_layer = final_fc_layer # type: ignore for param in self.final_fc_layer.parameters(): param.requires_grad = True diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index f24e56d7e1..682bff408d 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -36,7 +36,7 @@ class TestTracinValidator(BaseTest): ) def test_tracin_require_inputs_dataset( self, - reduction, + reduction: str, tracin_constructor: Callable, ) -> None: """ @@ -64,6 +64,10 @@ def test_tracin_require_inputs_dataset( tracin.influence(None, k=None) def test_tracincp_fast_rand_proj_inputs(self) -> None: + """ + This test verifies that TracInCPFast should be initialized + with a valid `final_fc_layer`. + """ with tempfile.TemporaryDirectory() as tmpdir: ( net, @@ -83,3 +87,34 @@ def test_tracincp_fast_rand_proj_inputs(self) -> None: loss_fn=nn.MSELoss(), batch_size=1, ) + + @parameterized.expand( + param_list, + name_func=build_test_name_func(), + ) + def test_tracincp_input_checkpoints( + self, reduction: str, tracin_constructor: Callable + ) -> None: + """ + This test verifies that tracinCP and tracinCPFast + class should be initialized with valid `checkpoints`. + """ + with tempfile.TemporaryDirectory() as invalid_tmpdir: + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + test_samples, + test_labels, + ) = get_random_model_and_data(tmpdir, unpack_inputs=False) + + with self.assertRaisesRegex( + ValueError, "Invalid checkpoints provided for TracIn class: " + ): + tracin_constructor( + net, + train_dataset, + invalid_tmpdir, + loss_fn=nn.MSELoss(), + batch_size=1, + )