@@ -851,6 +851,57 @@ def _influence_route_to_helpers(
851851 )
852852
853853
854+ def _parameter_dot (
855+ params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
856+ ) -> Tensor :
857+ """
858+ returns the dot-product of 2 tensors, represented as tuple of tensors.
859+ """
860+ return torch .Tensor (
861+ sum (
862+ torch .sum (param_1 * param_2 )
863+ for (param_1 , param_2 ) in zip (params_1 , params_2 )
864+ )
865+ )
866+
867+
868+ def _parameter_add (
869+ params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
870+ ) -> Tuple [Tensor , ...]:
871+ """
872+ returns the sum of 2 tensors, represented as tuple of tensors.
873+ """
874+ return tuple (param_1 + param_2 for (param_1 , param_2 ) in zip (params_1 , params_2 ))
875+
876+
877+ def _parameter_multiply (params : Tuple [Tensor , ...], c : Tensor ) -> Tuple [Tensor , ...]:
878+ """
879+ multiplies all tensors in a tuple of tensors by a given scalar
880+ """
881+ return tuple (param * c for param in params )
882+
883+
884+ def _parameter_to (params : Tuple [Tensor , ...], ** to_kwargs ) -> Tuple [Tensor , ...]:
885+ """
886+ applies the `to` method to all tensors in a tuple of tensors
887+ """
888+ return tuple (param .to (** to_kwargs ) for param in params )
889+
890+
891+ def _parameter_linear_combination (
892+ paramss : List [Tuple [Tensor , ...]], cs : Tensor
893+ ) -> Tuple [Tensor , ...]:
894+ """
895+ scales each parameter (tensor of tuples) in a list by the corresponding scalar in a
896+ 1D tensor of the same length, and sums up the scaled parameters
897+ """
898+ assert len (cs .shape ) == 1
899+ result = _parameter_multiply (paramss [0 ], cs [0 ])
900+ for (params , c ) in zip (paramss [1 :], cs [1 :]):
901+ result = _parameter_add (result , _parameter_multiply (params , c ))
902+ return result
903+
904+
854905def _compute_jacobian_sample_wise_grads_per_batch (
855906 influence_inst : Union ["TracInCP" , "InfluenceFunctionBase" ],
856907 inputs : Tuple [Any , ...],
@@ -1007,7 +1058,9 @@ def _functional_call(model, d, features):
10071058def _dataset_fn (dataloader , batch_fn , reduce_fn , * batch_fn_args , ** batch_fn_kwargs ):
10081059 """
10091060 Applies `batch_fn` to each batch in `dataloader`, reducing the results using
1010- `reduce_fn`. This is useful for computing Hessians over an entire dataloader.
1061+ `reduce_fn`. This is useful for computing Hessians and Hessian-vector
1062+ products over an entire dataloader, and is used by both `NaiveInfluenceFunction`
1063+ and `ArnoldiInfluenceFunction`.
10111064 """
10121065 _dataloader = iter (dataloader )
10131066
0 commit comments