44from  typing  import  (
55    Any ,
66    Callable ,
7+     Dict ,
78    Iterable ,
89    List ,
910    NamedTuple ,
@@ -613,7 +614,7 @@ def _influence_batch_intermediate_quantities_influence_function(
613614    influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
614615    test_batch : Tuple [Any , ...],
615616    train_batch : Tuple [Any , ...],
616- ):
617+ )  ->   Tensor :
617618    """ 
618619    computes influence of a test batch on a train batch, for implementations of 
619620    `IntermediateQuantitiesInfluenceFunction` 
@@ -628,7 +629,7 @@ def _influence_helper_intermediate_quantities_influence_function(
628629    influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
629630    inputs_dataset : Union [Tuple [Any , ...], DataLoader ],
630631    show_progress : bool ,
631- ):
632+ )  ->   Tensor :
632633    """ 
633634    Helper function that computes influence scores for implementations of 
634635    `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` 
@@ -666,7 +667,7 @@ def _self_influence_helper_intermediate_quantities_influence_function(
666667    influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
667668    inputs_dataset : Optional [Union [Tuple [Any , ...], DataLoader ]],
668669    show_progress : bool ,
669- ):
670+ )  ->   Tensor :
670671    """ 
671672    Helper function that computes self-influence scores for implementations of 
672673    `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` 
@@ -983,14 +984,14 @@ def _compute_batch_loss_influence_function_base(
983984        raise  Exception 
984985
985986
986- def  _set_attr (obj , names , val ):
987+ def  _set_attr (obj , names , val )  ->   None :
987988    if  len (names ) ==  1 :
988989        setattr (obj , names [0 ], val )
989990    else :
990991        _set_attr (getattr (obj , names [0 ]), names [1 :], val )
991992
992993
993- def  _del_attr (obj , names ):
994+ def  _del_attr (obj , names )  ->   None :
994995    if  len (names ) ==  1 :
995996        delattr (obj , names [0 ])
996997    else :
@@ -1006,7 +1007,7 @@ def _model_make_functional(model, param_names, params):
10061007    return  params 
10071008
10081009
1009- def  _model_reinsert_params (model , param_names , params , register = False ):
1010+ def  _model_reinsert_params (model , param_names , params , register :  bool   =   False )  ->   None :
10101011    for  param_name , param  in  zip (param_names , params ):
10111012        _set_attr (
10121013            model ,
@@ -1024,7 +1025,7 @@ def _custom_functional_call(model, d, features):
10241025    return  out 
10251026
10261027
1027- def  _functional_call (model , d , features ):
1028+ def  _functional_call (model :  Module , d :  Dict [ str ,  Tensor ] , features ):
10281029    """ 
10291030    Makes a call to `model.forward`, which is treated as a function of the parameters 
10301031    in `d`, a dict from parameter name to parameter, instead of as a function of 
0 commit comments