diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index 4300f0c0e1..c214ecbdf1 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -444,7 +444,7 @@ def _check_loss_fn( influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"], loss_fn: Optional[Union[Module, Callable]], loss_fn_name: str, - sample_wise_grads_per_batch: Optional[bool] = None, + sample_wise_grads_per_batch: bool = True, ) -> str: """ This checks whether `loss_fn` satisfies the requirements assumed of all @@ -469,16 +469,13 @@ def _check_loss_fn( # attribute. if hasattr(loss_fn, "reduction"): reduction = loss_fn.reduction # type: ignore - if sample_wise_grads_per_batch is None: + if sample_wise_grads_per_batch: assert reduction in [ "sum", "mean", - ], 'reduction for `loss_fn` must be "sum" or "mean"' - reduction_type = str(reduction) - elif sample_wise_grads_per_batch: - assert reduction in ["sum", "mean"], ( + ], ( 'reduction for `loss_fn` must be "sum" or "mean" when ' - "`sample_wise_grads_per_batch` is True" + "`sample_wise_grads_per_batch` is True (i.e. the default value) " ) reduction_type = str(reduction) else: @@ -490,18 +487,7 @@ def _check_loss_fn( # if we are unable to access the reduction used by `loss_fn`, we warn # the user about the assumptions we are making regarding the reduction # used by `loss_fn` - if sample_wise_grads_per_batch is None: - warnings.warn( - f'Since `{loss_fn_name}` has no "reduction" attribute, the ' - f'implementation assumes that `{loss_fn_name}` is a "reduction" loss ' - "function that reduces the per-example losses by taking their *sum*. " - f"If `{loss_fn_name}` instead reduces the per-example losses by " - f"taking their mean, please set the reduction attribute of " - f'`{loss_fn_name}` to "mean", i.e. ' - f'`{loss_fn_name}.reduction = "mean"`.' - ) - reduction_type = "sum" - elif sample_wise_grads_per_batch: + if sample_wise_grads_per_batch: warnings.warn( f"Since `{loss_fn_name}`` has no 'reduction' attribute, and " "`sample_wise_grads_per_batch` is True, the implementation assumes "