44from typing import (
55 Any ,
66 Callable ,
7+ Dict ,
78 Iterable ,
89 List ,
910 NamedTuple ,
@@ -613,7 +614,7 @@ def _influence_batch_intermediate_quantities_influence_function(
613614 influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
614615 test_batch : Tuple [Any , ...],
615616 train_batch : Tuple [Any , ...],
616- ):
617+ ) -> Tensor :
617618 """
618619 computes influence of a test batch on a train batch, for implementations of
619620 `IntermediateQuantitiesInfluenceFunction`
@@ -628,7 +629,7 @@ def _influence_helper_intermediate_quantities_influence_function(
628629 influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
629630 inputs_dataset : Union [Tuple [Any , ...], DataLoader ],
630631 show_progress : bool ,
631- ):
632+ ) -> Tensor :
632633 """
633634 Helper function that computes influence scores for implementations of
634635 `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -666,7 +667,7 @@ def _self_influence_helper_intermediate_quantities_influence_function(
666667 influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
667668 inputs_dataset : Optional [Union [Tuple [Any , ...], DataLoader ]],
668669 show_progress : bool ,
669- ):
670+ ) -> Tensor :
670671 """
671672 Helper function that computes self-influence scores for implementations of
672673 `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -983,14 +984,14 @@ def _compute_batch_loss_influence_function_base(
983984 raise Exception
984985
985986
986- def _set_attr (obj , names , val ):
987+ def _set_attr (obj , names , val ) -> None :
987988 if len (names ) == 1 :
988989 setattr (obj , names [0 ], val )
989990 else :
990991 _set_attr (getattr (obj , names [0 ]), names [1 :], val )
991992
992993
993- def _del_attr (obj , names ):
994+ def _del_attr (obj , names ) -> None :
994995 if len (names ) == 1 :
995996 delattr (obj , names [0 ])
996997 else :
@@ -1006,7 +1007,7 @@ def _model_make_functional(model, param_names, params):
10061007 return params
10071008
10081009
1009- def _model_reinsert_params (model , param_names , params , register = False ):
1010+ def _model_reinsert_params (model , param_names , params , register : bool = False ) -> None :
10101011 for param_name , param in zip (param_names , params ):
10111012 _set_attr (
10121013 model ,
@@ -1024,7 +1025,7 @@ def _custom_functional_call(model, d, features):
10241025 return out
10251026
10261027
1027- def _functional_call (model , d , features ):
1028+ def _functional_call (model : Module , d : Dict [ str , Tensor ] , features ):
10281029 """
10291030 Makes a call to `model.forward`, which is treated as a function of the parameters
10301031 in `d`, a dict from parameter name to parameter, instead of as a function of
0 commit comments