|  | 
| 4 | 4 | import warnings | 
| 5 | 5 | from abc import abstractmethod | 
| 6 | 6 | from os.path import join | 
| 7 |  | -from typing import ( | 
| 8 |  | -    Any, | 
| 9 |  | -    Callable, | 
| 10 |  | -    Iterator, | 
| 11 |  | -    List, | 
| 12 |  | -    NamedTuple, | 
| 13 |  | -    Optional, | 
| 14 |  | -    Tuple, | 
| 15 |  | -    Type, | 
| 16 |  | -    Union, | 
| 17 |  | -) | 
|  | 7 | +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union | 
| 18 | 8 | 
 | 
| 19 | 9 | import torch | 
| 20 | 10 | from captum._utils.av import AV | 
| 21 |  | -from captum._utils.common import _get_module_from_name, _parse_version | 
| 22 |  | -from captum._utils.gradient import ( | 
| 23 |  | -    _compute_jacobian_wrt_params, | 
| 24 |  | -    _compute_jacobian_wrt_params_with_sample_wise_trick, | 
| 25 |  | -) | 
|  | 11 | +from captum._utils.common import _parse_version | 
| 26 | 12 | from captum._utils.progress import NullProgress, progress | 
| 27 | 13 | from captum.influence._core.influence import DataInfluence | 
| 28 | 14 | from captum.influence._utils.common import ( | 
| 29 | 15 |     _check_loss_fn, | 
|  | 16 | +    _compute_jacobian_sample_wise_grads_per_batch, | 
| 30 | 17 |     _format_inputs_dataset, | 
| 31 | 18 |     _get_k_most_influential_helper, | 
| 32 | 19 |     _gradient_dot_product, | 
|  | 20 | +    _influence_route_to_helpers, | 
| 33 | 21 |     _load_flexible_state_dict, | 
| 34 | 22 |     _self_influence_by_batches_helper, | 
|  | 23 | +    _set_active_parameters, | 
|  | 24 | +    KMostInfluentialResults, | 
| 35 | 25 | ) | 
| 36 | 26 | from captum.log import log_usage | 
| 37 | 27 | from torch import Tensor | 
|  | 
| 69 | 59 | """ | 
| 70 | 60 | 
 | 
| 71 | 61 | 
 | 
| 72 |  | -class KMostInfluentialResults(NamedTuple): | 
| 73 |  | -    """ | 
| 74 |  | -    This namedtuple stores the results of using the `influence` method. This method | 
| 75 |  | -    is implemented by all subclasses of `TracInCPBase` to calculate | 
| 76 |  | -    proponents / opponents. The `indices` field stores the indices of the | 
| 77 |  | -    proponents / opponents for each example in the test dataset. For example, if | 
| 78 |  | -    finding opponents, `indices[i][j]` stores the index in the training data of the | 
| 79 |  | -    example with the `j`-th highest influence score on the `i`-th example in the test | 
| 80 |  | -    dataset. Similarly, the `influence_scores` field stores the actual influence | 
| 81 |  | -    scores, so that `influence_scores[i][j]` is the influence score of example | 
| 82 |  | -    `indices[i][j]` in the training data on example `i` of the test dataset. | 
| 83 |  | -    Please see `TracInCPBase.influence` for more details. | 
| 84 |  | -    """ | 
| 85 |  | - | 
| 86 |  | -    indices: Tensor | 
| 87 |  | -    influence_scores: Tensor | 
| 88 |  | - | 
| 89 |  | - | 
| 90 | 62 | class TracInCPBase(DataInfluence): | 
| 91 | 63 |     """ | 
| 92 | 64 |     To implement the `influence` method, classes inheriting from `TracInCPBase` will | 
| @@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str: | 
| 448 | 420 |         return cls.__name__ | 
| 449 | 421 | 
 | 
| 450 | 422 | 
 | 
| 451 |  | -def _influence_route_to_helpers( | 
| 452 |  | -    influence_instance: TracInCPBase, | 
| 453 |  | -    inputs: Union[Tuple[Any, ...], DataLoader], | 
| 454 |  | -    k: Optional[int] = None, | 
| 455 |  | -    proponents: bool = True, | 
| 456 |  | -    **kwargs, | 
| 457 |  | -) -> Union[Tensor, KMostInfluentialResults]: | 
| 458 |  | -    """ | 
| 459 |  | -    This is a helper function called by `TracInCP.influence` and | 
| 460 |  | -    `TracInCPFast.influence`. Those methods share a common logic in that they assume | 
| 461 |  | -    an instance of their respective classes implement 2 private methods | 
| 462 |  | -    (``_influence`, `_get_k_most_influential`), and the logic of | 
| 463 |  | -    which private method to call is common, as described in the documentation of the | 
| 464 |  | -    `influence` method. The arguments and return values of this function are the exact | 
| 465 |  | -    same as the `influence` method. Note that `influence_instance` refers to the | 
| 466 |  | -    instance for which the `influence` method was called. | 
| 467 |  | -    """ | 
| 468 |  | -    if k is None: | 
| 469 |  | -        return influence_instance._influence(inputs, **kwargs) | 
| 470 |  | -    else: | 
| 471 |  | -        return influence_instance._get_k_most_influential( | 
| 472 |  | -            inputs, | 
| 473 |  | -            k, | 
| 474 |  | -            proponents, | 
| 475 |  | -            **kwargs, | 
| 476 |  | -        ) | 
| 477 |  | - | 
| 478 |  | - | 
| 479 | 423 | class TracInCP(TracInCPBase): | 
| 480 | 424 |     def __init__( | 
| 481 | 425 |         self, | 
| @@ -630,23 +574,7 @@ def __init__( | 
| 630 | 574 |         """ | 
| 631 | 575 |         self.layer_modules = None | 
| 632 | 576 |         if layers is not None: | 
| 633 |  | -            assert isinstance(layers, List), "`layers` should be a list!" | 
| 634 |  | -            assert len(layers) > 0, "`layers` cannot be empty!" | 
| 635 |  | -            assert isinstance( | 
| 636 |  | -                layers[0], str | 
| 637 |  | -            ), "`layers` should contain str layer names." | 
| 638 |  | -            self.layer_modules = [ | 
| 639 |  | -                _get_module_from_name(self.model, layer) for layer in layers | 
| 640 |  | -            ] | 
| 641 |  | -            for layer, layer_module in zip(layers, self.layer_modules): | 
| 642 |  | -                for name, param in layer_module.named_parameters(): | 
| 643 |  | -                    if not param.requires_grad: | 
| 644 |  | -                        warnings.warn( | 
| 645 |  | -                            "Setting required grads for layer: {}, name: {}".format( | 
| 646 |  | -                                ".".join(layer), name | 
| 647 |  | -                            ) | 
| 648 |  | -                        ) | 
| 649 |  | -                        param.requires_grad = True | 
|  | 577 | +            self.layer_modules = _set_active_parameters(model, layers) | 
| 650 | 578 | 
 | 
| 651 | 579 |     @log_usage() | 
| 652 | 580 |     def influence(  # type: ignore[override] | 
| @@ -1463,19 +1391,6 @@ def _basic_computation_tracincp( | 
| 1463 | 1391 |                     argument is only used if `sample_wise_grads_per_batch` was true in | 
| 1464 | 1392 |                     initialization. | 
| 1465 | 1393 |         """ | 
| 1466 |  | -        if self.sample_wise_grads_per_batch: | 
| 1467 |  | -            return _compute_jacobian_wrt_params_with_sample_wise_trick( | 
| 1468 |  | -                self.model, | 
| 1469 |  | -                inputs, | 
| 1470 |  | -                targets, | 
| 1471 |  | -                loss_fn, | 
| 1472 |  | -                reduction_type, | 
| 1473 |  | -                self.layer_modules, | 
| 1474 |  | -            ) | 
| 1475 |  | -        return _compute_jacobian_wrt_params( | 
| 1476 |  | -            self.model, | 
| 1477 |  | -            inputs, | 
| 1478 |  | -            targets, | 
| 1479 |  | -            loss_fn, | 
| 1480 |  | -            self.layer_modules, | 
|  | 1394 | +        return _compute_jacobian_sample_wise_grads_per_batch( | 
|  | 1395 | +            self, inputs, targets, loss_fn, reduction_type | 
| 1481 | 1396 |         ) | 
0 commit comments